Example #1
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.norm_trans_emb = False
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('character')

        self.translation_alphabet = Alphabet('translation')
        self.translation_id_format = {}

        self.feature_name = []
        self.feature_alphabets = []
        self.feature_num = len(self.feature_alphabets)
        self.feat_config = None

        self.label_alphabet = Alphabet('label', True)
        self.tagScheme = "NoSeg"  ## BMES/BIO

        self.seg = True

        ### I/O
        self.train_dir = None
        self.dev_dir = None
        self.test_dir = None
        self.raw_dir = None

        self.trans_dir = None

        self.decode_dir = None
        self.dset_dir = None  ## data vocabulary related file
        self.model_dir = None  ## model save file
        self.load_model_dir = None  ## model load file

        self.word_emb_dir = None
        self.char_emb_dir = None
        self.trans_embed_dir = None

        self.feature_emb_dirs = []

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.pretrain_trans_embedding = None
        self.pretrain_feature_embeddings = []

        self.label_size = 0
        self.word_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0
        self.trans_alphabet_size = 0

        self.feature_alphabet_sizes = []
        self.feature_emb_dims = []
        self.norm_feature_embs = []
        self.word_emb_dim = 50
        self.char_emb_dim = 30
        self.trans_emb_dim = 100

        ###Networks
        self.word_feature_extractor = "LSTM"  ## "LSTM"/"CNN"/"GRU"/
        self.use_char = True
        self.char_seq_feature = "CNN"  ## "LSTM"/"CNN"/"GRU"/None
        self.use_trans = True
        self.use_crf = True
        self.nbest = None

        ## Training
        self.average_batch_loss = False
        self.optimizer = "SGD"  ## "SGD"/"AdaGrad"/"AdaDelta"/"RMSProp"/"Adam"
        self.status = "train"
        ### Hyperparameters
        self.HP_cnn_layer = 4
        self.HP_iteration = 100
        self.HP_batch_size = 10
        self.HP_char_hidden_dim = 50
        self.HP_trans_hidden_dim = 50
        self.HP_hidden_dim = 200
        self.HP_dropout = 0.5
        self.HP_lstm_layer = 1
        self.HP_bilstm = True

        self.HP_gpu = False
        self.HP_lr = 0.015
        self.HP_lr_decay = 0.05
        self.HP_clip = None
        self.HP_momentum = 0
        self.HP_l2 = 1e-8

    def show_data_summary(self):
        print("++" * 50)
        print("DATA SUMMARY START:")
        print(" I/O:")
        print("     Tag          scheme: %s" % (self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s" % (self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s" % (self.number_normalized))
        print("     Word  alphabet size: %s" % (self.word_alphabet_size))
        print("     Char  alphabet size: %s" % (self.char_alphabet_size))
        print("     Label alphabet size: %s" % (self.label_alphabet_size))
        print("     Trans alphabet size: %s" % (self.trans_alphabet_size))
        print("     Word embedding  dir: %s" % (self.word_emb_dir))
        print("     Char embedding  dir: %s" % (self.char_emb_dir))
        print("     Tran embedding  dir: %s" % (self.trans_embed_dir))
        print("     Word embedding size: %s" % (self.word_emb_dim))
        print("     Char embedding size: %s" % (self.char_emb_dim))
        print("     Tran embedding size: %s" % (self.trans_emb_dim))
        print("     Norm   word     emb: %s" % (self.norm_word_emb))
        print("     Norm   char     emb: %s" % (self.norm_char_emb))
        print("     Norm   tran     emb: %s" % (self.norm_trans_emb))
        print("     Train  file directory: %s" % (self.train_dir))
        print("     Dev    file directory: %s" % (self.dev_dir))
        print("     Test   file directory: %s" % (self.test_dir))
        print("     Raw    file directory: %s" % (self.raw_dir))
        print("     Dset   file directory: %s" % (self.dset_dir))
        print("     Model  file directory: %s" % (self.model_dir))
        print("     Loadmodel   directory: %s" % (self.load_model_dir))
        print("     Decode file directory: %s" % (self.decode_dir))
        print("     Train instance number: %s" % (len(self.train_texts)))
        print("     Dev   instance number: %s" % (len(self.dev_texts)))
        print("     Test  instance number: %s" % (len(self.test_texts)))
        print("     Raw   instance number: %s" % (len(self.raw_texts)))
        print("     FEATURE num: %s" % (self.feature_num))
        for idx in range(self.feature_num):
            print("         Fe: %s  alphabet  size: %s" %
                  (self.feature_alphabets[idx].name,
                   self.feature_alphabet_sizes[idx]))
            print(
                "         Fe: %s  embedding  dir: %s" %
                (self.feature_alphabets[idx].name, self.feature_emb_dirs[idx]))
            print(
                "         Fe: %s  embedding size: %s" %
                (self.feature_alphabets[idx].name, self.feature_emb_dims[idx]))
            print("         Fe: %s  norm       emb: %s" %
                  (self.feature_alphabets[idx].name,
                   self.norm_feature_embs[idx]))
        print(" " + "++" * 20)
        print(" Model Network:")
        print("     Model        use_crf: %s" % (self.use_crf))
        print("     Model word extractor: %s" % (self.word_feature_extractor))
        print("     Model       use_char: %s" % (self.use_char))
        if self.use_char:
            print("     Model char_seq_feature: %s" % (self.char_seq_feature))
            print("     Model char_hidden_dim: %s" % (self.HP_char_hidden_dim))
        if self.use_trans:
            print("     Model trans_hidden_dim: %s" %
                  (self.HP_trans_hidden_dim))
        print(" " + "++" * 20)
        print(" Training:")
        print("     Optimizer: %s" % (self.optimizer))
        print("     Iteration: %s" % (self.HP_iteration))
        print("     BatchSize: %s" % (self.HP_batch_size))
        print("     Average  batch   loss: %s" % (self.average_batch_loss))

        print(" " + "++" * 20)
        print(" Hyperparameters:")

        print("     Hyper              lr: %s" % (self.HP_lr))
        print("     Hyper        lr_decay: %s" % (self.HP_lr_decay))
        print("     Hyper         HP_clip: %s" % (self.HP_clip))
        print("     Hyper        momentum: %s" % (self.HP_momentum))
        print("     Hyper              l2: %s" % (self.HP_l2))
        print("     Hyper      hidden_dim: %s" % (self.HP_hidden_dim))
        print("     Hyper         dropout: %s" % (self.HP_dropout))
        print("     Hyper      lstm_layer: %s" % (self.HP_lstm_layer))
        print("     Hyper          bilstm: %s" % (self.HP_bilstm))
        print("     Hyper             GPU: %s" % (self.HP_gpu))
        print("DATA SUMMARY END.")
        print("++" * 50)
        sys.stdout.flush()

    def initial_feature_alphabets(self):
        items = open(self.train_dir, 'r').readline().strip('\n').split()
        print(items)
        total_column = len(items)
        if total_column > 2:
            for idx in range(1, total_column - 1):
                feature_prefix = items[idx].split(']', 1)[0] + "]"
                print("feature_prefix:{}".format(feature_prefix))
                self.feature_alphabets.append(Alphabet(feature_prefix))
                self.feature_name.append(feature_prefix)
                print("Find feature: ", feature_prefix)
        self.feature_num = len(self.feature_alphabets)
        self.pretrain_feature_embeddings = [None] * self.feature_num
        self.feature_emb_dims = [20] * self.feature_num
        self.feature_emb_dirs = [None] * self.feature_num
        self.norm_feature_embs = [False] * self.feature_num
        self.feature_alphabet_sizes = [0] * self.feature_num
        if self.feat_config:
            for idx in range(self.feature_num):
                if self.feature_name[idx] in self.feat_config:
                    self.feature_emb_dims[idx] = self.feat_config[
                        self.feature_name[idx]]['emb_size']
                    self.feature_emb_dirs[idx] = self.feat_config[
                        self.feature_name[idx]]['emb_dir']
                    self.norm_feature_embs[idx] = self.feat_config[
                        self.feature_name[idx]]['emb_norm']
        # exit(0)

    def build_alphabet(self, input_file):
        print("Build alphabet......")
        in_lines = open(input_file, 'r', encoding="utf-8").readlines()
        for line in in_lines:
            if len(line) > 2:
                pairs = line.strip().split()
                word = pairs[0]
                if self.number_normalized:
                    word = normalize_word(word)
                label = pairs[-1]
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)
                ## build feature alphabet
                for idx in range(self.feature_num):
                    feat_idx = pairs[idx + 1].split(']', 1)[-1]
                    self.feature_alphabets[idx].add(feat_idx)
                for char in word:
                    self.char_alphabet.add(char)
        self.word_alphabet_size = self.word_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        for idx in range(self.feature_num):
            self.feature_alphabet_sizes[idx] = self.feature_alphabets[
                idx].size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"

    def fix_alphabet(self):
        self.word_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close()
        self.translation_alphabet.close()
        for idx in range(self.feature_num):
            self.feature_alphabets[idx].close()

    def build_pretrain_emb(self):
        if self.word_emb_dir:
            print("Load pretrained word embedding, norm: %s, dir: %s" %
                  (self.norm_word_emb, self.word_emb_dir))
            self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
                self.word_emb_dir, self.word_alphabet, self.word_emb_dim,
                self.norm_word_emb)
        if self.char_emb_dir:
            print("Load pretrained char embedding, norm: %s, dir: %s" %
                  (self.norm_char_emb, self.char_emb_dir))
            self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(
                self.char_emb_dir, self.char_alphabet, self.char_emb_dim,
                self.norm_char_emb)
        if self.trans_embed_dir:
            print("Load pretrained trans embedding, norm: %s, dir: %s" %
                  (self.norm_trans_emb, self.trans_embed_dir))
            self.pretrain_trans_embedding, self.trans_emb_dim = build_pretrain_embedding(
                self.trans_embed_dir, self.translation_alphabet,
                self.trans_emb_dim, self.norm_trans_emb)

        for idx in range(self.feature_num):
            if self.feature_emb_dirs[idx]:
                print(
                    "Load pretrained feature %s embedding:, norm: %s, dir: %s"
                    % (self.feature_name[idx], self.norm_feature_embs[idx],
                       self.feature_emb_dirs[idx]))
                self.pretrain_feature_embeddings[idx], self.feature_emb_dims[
                    idx] = build_pretrain_embedding(
                        self.feature_emb_dirs[idx],
                        self.feature_alphabets[idx],
                        self.feature_emb_dims[idx],
                        self.norm_feature_embs[idx])

    def generate_instance(self, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance(
                self.train_dir, self.word_alphabet, self.char_alphabet,
                self.feature_alphabets, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH,
                self.translation_id_format)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance(
                self.dev_dir, self.word_alphabet, self.char_alphabet,
                self.feature_alphabets, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH,
                self.translation_id_format)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance(
                self.test_dir, self.word_alphabet, self.char_alphabet,
                self.feature_alphabets, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH,
                self.translation_id_format)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance(
                self.raw_dir, self.word_alphabet, self.char_alphabet,
                self.feature_alphabets, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH,
                self.translation_id_format)
        else:
            print(
                "Error: you can only generate train/dev/test/raw instance! Illegal input:%s"
                % (name))

    def write_decoded_results(self, predict_results, name):
        fout = open(self.decode_dir, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be within train/dev/test/raw !"
            )
        assert (sent_num == len(content_list))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                ## content_list[idx] is a list with [word, char, label]
                fout.write(content_list[idx][0][idy] + " " +
                           predict_results[idx][idy] + '\n')
            fout.write('\n')
        fout.close()
        print("Predict %s result has been written into file. %s" %
              (name, self.decode_dir))

    def load(self, data_file):
        f = open(data_file, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def save(self, save_file):
        f = open(save_file, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    def write_nbest_decoded_results(self, predict_results, pred_scores, name):
        ## predict_results : [whole_sent_num, nbest, each_sent_length]
        ## pred_scores: [whole_sent_num, nbest]
        fout = open(self.decode_dir, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be within train/dev/test/raw !"
            )
        assert (sent_num == len(content_list))
        assert (sent_num == len(pred_scores))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx][0])
            nbest = len(predict_results[idx])
            score_string = "# "
            for idz in range(nbest):
                score_string += format(pred_scores[idx][idz], '.4f') + " "
            fout.write(score_string.strip() + "\n")

            for idy in range(sent_length):
                label_string = content_list[idx][0][idy] + " "
                for idz in range(nbest):
                    label_string += predict_results[idx][idz][idy] + " "
                label_string = label_string.strip() + "\n"
                fout.write(label_string)
            fout.write('\n')
        fout.close()
        print("Predict %s %s-best result has been written into file. %s" %
              (name, nbest, self.decode_dir))

    def read_config(self, config_file):
        config = config_file_to_dict(config_file)
        ## read data:
        the_item = 'train_dir'
        if the_item in config:
            self.train_dir = config[the_item]
        the_item = 'dev_dir'
        if the_item in config:
            self.dev_dir = config[the_item]
        the_item = 'test_dir'
        if the_item in config:
            self.test_dir = config[the_item]

        the_item = 'trans_dir'
        if the_item in config:
            self.trans_dir = config[the_item]

        the_item = 'raw_dir'
        if the_item in config:
            self.raw_dir = config[the_item]
        the_item = 'decode_dir'
        if the_item in config:
            self.decode_dir = config[the_item]
        the_item = 'dset_dir'
        if the_item in config:
            self.dset_dir = config[the_item]
        the_item = 'model_dir'
        if the_item in config:
            self.model_dir = config[the_item]
        the_item = 'load_model_dir'
        if the_item in config:
            self.load_model_dir = config[the_item]

        the_item = 'word_emb_dir'
        if the_item in config:
            self.word_emb_dir = config[the_item]
        the_item = 'char_emb_dir'
        if the_item in config:
            self.char_emb_dir = config[the_item]
        the_item = 'trans_embed_dir'
        if the_item in config:
            self.trans_embed_dir = config[the_item]

        the_item = 'MAX_SENTENCE_LENGTH'
        if the_item in config:
            self.MAX_SENTENCE_LENGTH = int(config[the_item])
        the_item = 'MAX_WORD_LENGTH'
        if the_item in config:
            self.MAX_WORD_LENGTH = int(config[the_item])

        the_item = 'norm_word_emb'
        if the_item in config:
            self.norm_word_emb = str2bool(config[the_item])
        the_item = 'norm_char_emb'
        if the_item in config:
            self.norm_char_emb = str2bool(config[the_item])
        the_item = 'number_normalized'
        if the_item in config:
            self.number_normalized = str2bool(config[the_item])

        the_item = 'seg'
        if the_item in config:
            self.seg = str2bool(config[the_item])
        the_item = 'word_emb_dim'
        if the_item in config:
            self.word_emb_dim = int(config[the_item])
        the_item = 'char_emb_dim'
        if the_item in config:
            self.char_emb_dim = int(config[the_item])
        the_item = 'trans_emb_dim'
        if the_item in config:
            self.trans_emb_dim = int(config[the_item])

        ## read network:
        the_item = 'use_crf'
        if the_item in config:
            self.use_crf = str2bool(config[the_item])
        the_item = 'use_char'
        if the_item in config:
            self.use_char = str2bool(config[the_item])
        the_item = 'use_trans'
        if the_item in config:
            self.use_trans = str2bool(config[the_item])
        the_item = 'word_seq_feature'
        if the_item in config:
            self.word_feature_extractor = config[the_item]
        the_item = 'char_seq_feature'
        if the_item in config:
            self.char_seq_feature = config[the_item]
        the_item = 'nbest'
        if the_item in config:
            self.nbest = int(config[the_item])

        the_item = 'feature'
        if the_item in config:
            self.feat_config = config[the_item]  ## feat_config is a dict

        ## read training setting:
        the_item = 'optimizer'
        if the_item in config:
            self.optimizer = config[the_item]
        the_item = 'ave_batch_loss'
        if the_item in config:
            self.average_batch_loss = str2bool(config[the_item])
        the_item = 'status'
        if the_item in config:
            self.status = config[the_item]

        ## read Hyperparameters:
        the_item = 'cnn_layer'
        if the_item in config:
            self.HP_cnn_layer = int(config[the_item])
        the_item = 'iteration'
        if the_item in config:
            self.HP_iteration = int(config[the_item])
        the_item = 'batch_size'
        if the_item in config:
            self.HP_batch_size = int(config[the_item])

        the_item = 'char_hidden_dim'
        if the_item in config:
            self.HP_char_hidden_dim = int(config[the_item])

        the_item = 'trans_hidden_dim'
        if the_item in config:
            self.HP_trans_hidden_dim = int(config[the_item])

        the_item = 'hidden_dim'
        if the_item in config:
            self.HP_hidden_dim = int(config[the_item])
        the_item = 'dropout'
        if the_item in config:
            self.HP_dropout = float(config[the_item])
        the_item = 'lstm_layer'
        if the_item in config:
            self.HP_lstm_layer = int(config[the_item])
        the_item = 'bilstm'
        if the_item in config:
            self.HP_bilstm = str2bool(config[the_item])

        the_item = 'gpu'
        if the_item in config:
            self.HP_gpu = str2bool(config[the_item])
        the_item = 'learning_rate'
        if the_item in config:
            self.HP_lr = float(config[the_item])
        the_item = 'lr_decay'
        if the_item in config:
            self.HP_lr_decay = float(config[the_item])
        the_item = 'clip'
        if the_item in config:
            self.HP_clip = float(config[the_item])
        the_item = 'momentum'
        if the_item in config:
            self.HP_momentum = float(config[the_item])
        the_item = 'l2'
        if the_item in config:
            self.HP_l2 = float(config[the_item])

    def build_translation_alphabet(self, trans_path):
        print("Creating translation alphabet......")
        with codecs.open(trans_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            for line in lines:
                if len(line.strip().split(":")) == 2:
                    temp = line.strip().split(":", 1)
                    words = temp[1].split()
                    for word in words:
                        self.translation_alphabet.add(word.strip())
        self.trans_alphabet_size = self.translation_alphabet.size()

    def build_translation_dict(self, trans_path):
        print("Creating Id to Id translation dictionary......")
        translation_id_format_temp = {}
        with codecs.open(trans_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            for line in lines:
                ids = []
                if len(line.strip().split(":")) == 2:
                    temp = line.strip().split(":")
                    word_id = self.word_alphabet.get_index(temp[0].strip())
                    translations = temp[1].split()
                    for translation in translations:
                        ids.append(
                            self.translation_alphabet.get_index(
                                translation.strip()))
                    translation_id_format_temp[word_id] = ids

        for word in self.word_alphabet.instances:
            if self.word_alphabet.get_index(
                    word) in translation_id_format_temp.keys():
                self.translation_id_format[self.word_alphabet.get_index(
                    word)] = translation_id_format_temp[
                        self.word_alphabet.get_index(word)]
            else:
                self.translation_id_format[self.word_alphabet.get_index(
                    word)] = [0]
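
A minimal driver sketch for this variant. The call order mirrors the methods above; the config filename is a placeholder, and helpers such as config_file_to_dict, read_instance, and build_pretrain_embedding are assumed to come from the project's utils.

# Hypothetical end-to-end usage; all paths come from the (assumed) config file.
data = Data()
data.read_config("demo.train.config")   # hypothetical config file
data.initial_feature_alphabets()
data.build_alphabet(data.train_dir)
data.build_alphabet(data.dev_dir)
data.build_alphabet(data.test_dir)
data.build_translation_alphabet(data.trans_dir)
data.build_translation_dict(data.trans_dir)
data.fix_alphabet()
data.build_pretrain_emb()
data.generate_instance("train")
data.generate_instance("dev")
data.generate_instance("test")
data.show_data_summary()
data.save(data.dset_dir)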
Example #2
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = True
        self.norm_word_emb = True
        self.norm_biword_emb = True
        self.norm_gaz_emb = False
        self.word_alphabet = Alphabet('word')
        self.biword_alphabet = Alphabet('biword')
        self.char_alphabet = Alphabet('character')
        self.label_alphabet = Alphabet('label', True)
        #self.simi_alphabet = Alphabet('simi')  # alphabet for words used in similarity computation
        self.gaz_lower = False
        self.gaz = Gazetteer(self.gaz_lower)
        self.gaz_alphabet = Alphabet('gaz')
        self.gaz_count = {}
        self.gaz_split = {}
        self.biword_count = {}

        self.HP_fix_gaz_emb = False
        self.HP_use_gaz = True
        self.HP_use_count = False

        self.tagScheme = "NoSeg"
        self.char_features = "LSTM"

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []

        self.train_split_index = []
        self.dev_split_index = []

        self.use_bigram = True
        self.word_emb_dim = 200
        self.biword_emb_dim = 200
        self.char_emb_dim = 30
        self.gaz_emb_dim = 200
        self.gaz_dropout = 0.5
        self.pretrain_word_embedding = None
        self.pretrain_biword_embedding = None
        self.pretrain_gaz_embedding = None
        self.label_size = 0
        self.word_alphabet_size = 0
        self.biword_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0

        # dictionary-similarity parameters
        self.simi_dic_emb = None  # similarity embedding values
        self.simi_dic_dim = 10  # dimension of the similarity vector
        self.use_dictionary = False  # whether a dictionary is currently used
        self.simi_list = []  # similarity value stored for each character
        # self.use_gazcount = 'True'

        ### hyperparameters
        self.HP_iteration = 60
        self.HP_batch_size = 10
        self.HP_char_hidden_dim = 50
        self.HP_hidden_dim = 128
        self.HP_dropout = 0.5
        self.HP_lstm_layer = 1
        self.HP_bilstm = True
        self.HP_use_char = False
        self.HP_gpu = True
        self.HP_lr = 0.015
        self.HP_lr_decay = 0.05
        self.HP_clip = 5.0
        self.HP_momentum = 0

        self.HP_num_layer = 4

    def show_data_summary(self):
        print("DATA SUMMARY START:")
        print("     Tag          scheme: %s" % (self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s" % (self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s" % (self.number_normalized))
        print("     Use          bigram: %s" % (self.use_bigram))
        print("     Word  alphabet size: %s" % (self.word_alphabet_size))
        print("     Biword alphabet size: %s" % (self.biword_alphabet_size))
        print("     Char  alphabet size: %s" % (self.char_alphabet_size))
        print("     Gaz   alphabet size: %s" % (self.gaz_alphabet.size()))
        print("     Label alphabet size: %s" % (self.label_alphabet_size))
        print("     Word embedding size: %s" % (self.word_emb_dim))
        print("     Biword embedding size: %s" % (self.biword_emb_dim))
        print("     Char embedding size: %s" % (self.char_emb_dim))
        print("     Gaz embedding size: %s" % (self.gaz_emb_dim))
        print("     Norm     word   emb: %s" % (self.norm_word_emb))
        print("     Norm     biword emb: %s" % (self.norm_biword_emb))
        print("     Norm     gaz    emb: %s" % (self.norm_gaz_emb))
        print("     Norm   gaz  dropout: %s" % (self.gaz_dropout))
        print("     Train instance number: %s" % (len(self.train_texts)))
        print("     Dev   instance number: %s" % (len(self.dev_texts)))
        print("     Test  instance number: %s" % (len(self.test_texts)))
        print("     Raw   instance number: %s" % (len(self.raw_texts)))
        print("     Hyperpara  iteration: %s" % (self.HP_iteration))
        print("     Hyperpara  batch size: %s" % (self.HP_batch_size))
        print("     Hyperpara          lr: %s" % (self.HP_lr))
        print("     Hyperpara    lr_decay: %s" % (self.HP_lr_decay))
        print("     Hyperpara     HP_clip: %s" % (self.HP_clip))
        print("     Hyperpara    momentum: %s" % (self.HP_momentum))
        print("     Hyperpara  hidden_dim: %s" % (self.HP_hidden_dim))
        print("     Hyperpara     dropout: %s" % (self.HP_dropout))
        print("     Hyperpara  lstm_layer: %s" % (self.HP_lstm_layer))
        print("     Hyperpara      bilstm: %s" % (self.HP_bilstm))
        print("     Hyperpara         GPU: %s" % (self.HP_gpu))
        print("     Hyperpara     use_gaz: %s" % (self.HP_use_gaz))
        print("     Hyperpara fix gaz emb: %s" % (self.HP_fix_gaz_emb))
        print("     Hyperpara    use_char: %s" % (self.HP_use_char))
        if self.HP_use_char:
            print("             Char_features: %s" % (self.char_features))
        print("DATA SUMMARY END.")
        sys.stdout.flush()

    def refresh_label_alphabet(self, input_file):
        old_size = self.label_alphabet_size
        self.label_alphabet.clear(True)
        in_lines = open(input_file, 'r', encoding="utf-8").readlines()
        for line in in_lines:
            if len(line) > 2:
                pairs = line.strip().split()
                label = pairs[-1]
                self.label_alphabet.add(label)
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"
        self.fix_alphabet()
        print("Refresh label alphabet finished: old:%s -> new:%s" %
              (old_size, self.label_alphabet_size))

    def build_alphabet(self, input_file):
        in_lines = open(input_file, 'r', encoding="utf-8").readlines()
        seqlen = 0
        for idx in range(len(in_lines)):
            line = in_lines[idx]
            if len(line) > 2:
                pairs = line.strip().split()
                word = pairs[0]
                if self.number_normalized:
                    word = normalize_word(word)
                label = pairs[-1]
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)
                if idx < len(in_lines) - 1 and len(in_lines[idx + 1]) > 2:
                    biword = word + in_lines[idx + 1].strip().split()[0]
                else:
                    biword = word + NULLKEY
                self.biword_alphabet.add(biword)
                # biword_index = self.biword_alphabet.get_index(biword)
                self.biword_count[biword] = self.biword_count.get(biword,
                                                                  0) + 1
                for char in word:
                    self.char_alphabet.add(char)
                # length of the current sentence
                seqlen += 1
            else:
                # an empty line marks a sentence boundary; reset the counter
                seqlen = 0
        # record the size of each alphabet
        self.word_alphabet_size = self.word_alphabet.size()
        self.biword_alphabet_size = self.biword_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"

    def build_gaz_file(self, gaz_file):
        ## build the gaz trie from the first column of the gaz/embedding lexicon file
        if gaz_file:
            fins = open(gaz_file, 'r', encoding="utf-8").readlines()
            for fin in fins:
                fin = fin.strip().split()[0]
                if fin:
                    self.gaz.insert(fin, "one_source")
            print("Load gaz file: ", gaz_file, " total size:", self.gaz.size())
        else:
            print("Gaz file is None, load nothing")

    #def build_dict_alphabet(

    def build_gaz_alphabet(self, input_file, count=False):
        in_lines = open(input_file, 'r', encoding="utf-8").readlines()
        word_list = []
        for line in in_lines:
            if len(line) > 3:
                word = line.split()[0]
                if self.number_normalized:
                    word = normalize_word(word)
                word_list.append(word)
            else:
                # word_list holds all characters of the current sentence
                w_length = len(word_list)
                entitys = []
                # a complete sentence has been collected; match lexicon entries
                for idx in range(w_length):
                    matched_entity = self.gaz.enumerateMatchList(
                        word_list[idx:])
                    entitys += matched_entity
                    for entity in matched_entity:
                        # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity)
                        self.gaz_alphabet.add(entity)
                        index = self.gaz_alphabet.get_index(entity)

                        self.gaz_count[index] = self.gaz_count.get(
                            index, 0)  ## initialize gaz count
                        # .get returns 0 when this index key is not yet present

                if count:
                    entitys.sort(key=lambda x: -len(x))
                    while entitys:
                        longest = entitys[0]
                        longest_index = self.gaz_alphabet.get_index(longest)
                        # increment the count of the longest matched word
                        self.gaz_count[longest_index] = self.gaz_count.get(
                            longest_index, 0) + 1
                        # remove every entity covered by the longest match
                        gazlen = len(longest)
                        for i in range(gazlen):
                            for j in range(i + 1, gazlen + 1):
                                covering_gaz = longest[i:j]
                                if covering_gaz in entitys:
                                    entitys.remove(covering_gaz)
                                    # print('remove:',covering_gaz)
                word_list = []
        print("gaz alphabet size:", self.gaz_alphabet.size())

    def fix_alphabet(self):
        self.word_alphabet.close()
        self.biword_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close()
        self.gaz_alphabet.close()

    def build_word_pretrain_emb(self, emb_path):
        print("build word pretrain emb...")
        self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
            emb_path, self.word_alphabet, self.word_emb_dim,
            self.norm_word_emb)

    def build_biword_pretrain_emb(self, emb_path):
        print("build biword pretrain emb...")
        self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(
            emb_path, self.biword_alphabet, self.biword_emb_dim,
            self.norm_biword_emb)

    def build_gaz_pretrain_emb(self, emb_path):
        print("build gaz pretrain emb...")
        self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(
            emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb)

    def generate_instance_with_gaz(self, input_file, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance_with_gaz(
                self.HP_num_layer, input_file, self.gaz, self.word_alphabet,
                self.biword_alphabet, self.biword_count, self.char_alphabet,
                self.gaz_alphabet, self.gaz_count, self.gaz_split,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance_with_gaz(
                self.HP_num_layer, input_file, self.gaz, self.word_alphabet,
                self.biword_alphabet, self.biword_count, self.char_alphabet,
                self.gaz_alphabet, self.gaz_count, self.gaz_split,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance_with_gaz(
                self.HP_num_layer, input_file, self.gaz, self.word_alphabet,
                self.biword_alphabet, self.biword_count, self.char_alphabet,
                self.gaz_alphabet, self.gaz_count, self.gaz_split,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance_with_gaz(
                self.HP_num_layer, input_file, self.gaz, self.word_alphabet,
                self.biword_alphabet, self.biword_count, self.char_alphabet,
                self.gaz_alphabet, self.gaz_count, self.gaz_split,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        else:
            print(
                "Error: you can only generate train/dev/test/raw instance! Illegal input:%s"
                % (name))

    def write_decoded_results(self, output_file, predict_results, name):
        fout = open(output_file, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be within train/dev/test/raw !"
            )
        assert (sent_num == len(content_list))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                ## content_list[idx] is a list with [word, char, label]
                fout.write(content_list[idx][0][idy] + " " +
                           predict_results[idx][idy] + '\n')

            fout.write('\n')
        fout.close()
        print("Predict %s result has been written into file. %s" %
              (name, output_file))
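
A minimal driver sketch for this gazetteer-augmented variant. All file names below are placeholders, and read_instance_with_gaz / build_pretrain_embedding are assumed project helpers.

# Hypothetical usage: build alphabets and the gaz trie, then load embeddings.
data = Data()
data.build_gaz_file("ctb.50d.vec")           # hypothetical lexicon file
data.build_alphabet("train.char.bmes")       # hypothetical data files
data.build_alphabet("dev.char.bmes")
data.build_gaz_alphabet("train.char.bmes", count=True)
data.build_gaz_alphabet("dev.char.bmes", count=True)
data.fix_alphabet()
data.build_word_pretrain_emb("gigaword.char.vec")
data.build_biword_pretrain_emb("gigaword.bichar.vec")
data.build_gaz_pretrain_emb("ctb.50d.vec")
data.generate_instance_with_gaz("train.char.bmes", "train")
data.show_data_summary()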
Example #3
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = False
        self.norm_word_emb = True
        self.norm_biword_emb = True
        self.word_alphabet = Alphabet('word')
        self.biword_alphabet = Alphabet('biword')
        self.pos_alphabet = Alphabet('pos')
        self.label_alphabet = Alphabet('label', True)

        self.tagScheme = "NoSeg"
        self.char_features = "LSTM"

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []
        self.use_bigram = False
        self.word_emb_dim = 50
        self.biword_emb_dim = 50

        self.pretrain_word_embedding = None
        self.pretrain_biword_embedding = None
        self.label_size = 0
        self.word_alphabet_size = 0
        self.biword_alphabet_size = 0
        self.label_alphabet_size = 0
        #  hyperparameters
        self.HP_iteration = 100
        self.HP_batch_size = 16
        self.HP_char_hidden_dim = 50
        self.HP_hidden_dim = 200
        self.HP_dropout = 0.2
        self.HP_lstmdropout = 0
        self.HP_lstm_layer = 1
        self.HP_bilstm = True
        self.HP_gpu = False
        self.HP_lr = 0.015
        self.HP_lr_decay = 0.05
        self.HP_clip = 5.0
        self.HP_momentum = 0

        #  attention
        self.tencent_word_embed_dim = 200
        self.pos_embed_dim = 200
        self.cross_domain = False
        self.cross_test = False
        self.use_san = False
        self.use_cnn = False
        self.use_attention = True
        self.pos_to_idx = {}
        self.external_pos = {}
        self.token_replace_prob = {}
        self.use_adam = False
        self.use_bert = False
        self.use_warmup_adam = False
        self.use_sgd = False
        self.use_adadelta = False
        self.use_window = True
        self.mode = 'train'
        self.use_tencent_dic = False

        # cross domain file
        self.computer_file = ""
        self.finance_file = ""
        self.medicine_file = ""
        self.literature_file = ""

    def show_data_summary(self):
        print("DATA SUMMARY START:")
        print("     Tag          scheme: %s" % (self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s" % (self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s" % (self.number_normalized))
        print("     Use          bigram: %s" % (self.use_bigram))
        print("     Char  alphabet size: %s" % (self.word_alphabet_size))
        print("     BiChar alphabet size: %s" % (self.biword_alphabet_size))
        print("     Label alphabet size: %s" % (self.label_alphabet_size))
        print("     Char embedding size: %s" % (self.word_emb_dim))
        print("     BiChar embedding size: %s" % (self.biword_emb_dim))
        print("     Norm     char   emb: %s" % (self.norm_word_emb))
        print("     Norm     bichar emb: %s" % (self.norm_biword_emb))
        print("     Train instance number: %s" % (len(self.train_texts)))
        print("     Dev   instance number: %s" % (len(self.dev_texts)))
        print("     Test  instance number: %s" % (len(self.test_texts)))
        print("     Raw   instance number: %s" % (len(self.raw_texts)))
        print("     Hyperpara  iteration: %s" % (self.HP_iteration))
        print("     Hyperpara  batch size: %s" % (self.HP_batch_size))
        print("     Hyperpara          lr: %s" % (self.HP_lr))
        print("     Hyperpara    lr_decay: %s" % (self.HP_lr_decay))
        print("     Hyperpara     HP_clip: %s" % (self.HP_clip))
        print("     Hyperpara    momentum: %s" % (self.HP_momentum))
        print("     Hyperpara  hidden_dim: %s" % (self.HP_hidden_dim))
        print("     Hyperpara     dropout: %s" % (self.HP_dropout))
        print("     Hyperpara  lstm_layer: %s" % (self.HP_lstm_layer))
        print("     Hyperpara      bilstm: %s" % (self.HP_bilstm))
        print("     Hyperpara         GPU: %s" % (self.HP_gpu))
        print("     Cross domain: %s" % self.cross_domain)
        print("     Hyperpara  use window: %s" % self.use_window)
        print("DATA SUMMARY END.")
        sys.stdout.flush()

    def build_alphabet(self, input_file):
        in_lines = open(input_file, 'r').readlines()
        for idx in range(len(in_lines)):
            line = in_lines[idx]
            if len(line) > 2:
                pairs = line.strip().split('\t')
                # word = pairs[0].decode('utf-8')
                word = pairs[0]
                if self.number_normalized:
                    word = normalize_word(word)
                label = pairs[-1][0] + '-SEG'
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)
                if idx < len(in_lines) - 1 and len(in_lines[idx + 1]) > 2:
                    # biword = word + in_lines[idx + 1].strip('\t').split()[0].decode('utf-8')
                    biword = word + in_lines[idx + 1].strip('\t').split()[0]
                else:
                    biword = word + NULLKEY
                self.biword_alphabet.add(biword)

        self.word_alphabet_size = self.word_alphabet.size()
        self.biword_alphabet_size = self.biword_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"

    def fix_alphabet(self):
        self.word_alphabet.close()
        self.biword_alphabet.close()
        self.label_alphabet.close()

    def build_word_pretrain_emb(self, emb_path):
        print("build word pretrain emb...")
        self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
            emb_path, self.word_alphabet, self.word_emb_dim,
            self.norm_word_emb)

    def build_biword_pretrain_emb(self, emb_path):
        print("build biword pretrain emb...")
        self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(
            emb_path, self.biword_alphabet, self.biword_emb_dim,
            self.norm_biword_emb)

    def build_word_vec_100(self):
        self.pretrain_word_embedding, self.pretrain_biword_embedding = self.get_embedding(
        )
        self.word_emb_dim, self.biword_emb_dim = 100, 100

    # get pre-trained embeddings
    def get_embedding(self, size=100):
        fname = 'data/wordvec_' + str(size)
        print("build pretrain word embedding from: ", fname)
        word_init_embedding = np.zeros(shape=[self.word_alphabet.size(), size])
        bi_word_init_embedding = np.zeros(
            shape=[self.biword_alphabet.size(), size])
        pre_trained = gensim.models.KeyedVectors.load(fname, mmap='r')
        # pre_trained_vocab = set([unicode(w.decode('utf8')) for w in pre_trained.vocab.keys()])
        pre_trained_vocab = set(pre_trained.vocab.keys())
        c = 0
        for word, index in self.word_alphabet.iteritems():
            if word in pre_trained_vocab:
                word_init_embedding[index] = pre_trained[word]
            else:
                word_init_embedding[index] = np.random.uniform(-0.5, 0.5, size)
                c += 1

        for word, index in self.biword_alphabet.iteritems():
            bi_word_init_embedding[index] = (
                word_init_embedding[self.word_alphabet.get_index(word[0])] +
                word_init_embedding[self.word_alphabet.get_index(word[1])]) / 2
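        # Each bichar vector is the mean of its two characters' vectors, so
        # (e.g., for a hypothetical biword "ab") biword embeddings stay in
        # the same space as the character embeddings.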
        # word_init_embedding[word2id[PAD]] = np.zeros(shape=size)
        # bi_word_init_embedding[]
        print('oov character rate %f' % (float(c) / self.word_alphabet.size()))
        return word_init_embedding, bi_word_init_embedding

    def generate_instance(self, input_file, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        else:
            print(
                "Error: you can only generate train/dev/test/raw instance! Illegal input:%s"
                % (name))

    def write_decoded_results(self, output_file, predict_results, name):
        fout = open(output_file, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be within train/dev/test/raw !"
            )
        assert (sent_num == len(content_list))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                ## content_list[idx] is a list with [word, char, label]
                fout.write(content_list[idx][0][idy] + "\t" +
                           predict_results[idx][idy][0] + '\n')

            fout.write('\n')
        fout.close()
        print("Predict %s result has been written into file. %s" %
              (name, output_file))
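
A minimal driver sketch for this char/bichar segmentation variant. Paths are placeholders; read_instance and the gensim vectors under data/wordvec_100 are assumed to exist.

# Hypothetical usage for the attention-based segmenter's data pipeline.
data = Data()
data.build_alphabet("train.tsv")             # hypothetical tab-separated file
data.build_alphabet("dev.tsv")
data.fix_alphabet()
data.build_word_vec_100()                    # loads data/wordvec_100 via gensim
data.generate_instance("train.tsv", "train")
data.generate_instance("dev.tsv", "dev")
data.show_data_summary()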
Example #4
class Data(object):
    def __init__(self, args):
        super(Data, self).__init__()
        self.args = args
        self.data_dir = args.data_dir  # './data/gene_term_format_by_sentence.json'
        self.data_ratio = (0.9, 0.05, 0.05)  # total 2000
        self.model_save_dir = args.savemodel  # './saves/model/'
        self.output_dir = args.output  # './saves/output/'
        self.data_save_file = args.savedset  # './saves/data/dat.pkl'

        self.pos_as_feature = args.use_pos
        self.use_elmo = args.use_elmo
        self.elmodim = args.elmodim
        self.pos_emb_dim = args.posdim
        self.useSpanLen = args.use_len
        self.use_sentence_att = args.use_sent_att
        self.use_char = True
        self.ranking = 1

        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('character')
        self.ptag_alphabet = Alphabet('tag')
        self.label_alphabet = Alphabet('label', label=True)
        self.seqlabel_alphabet = Alphabet('span_label', label=True)

        self.word_alphabet_size = 0
        self.char_alphabet_size = 0
        self.ptag_alphabet_size = 0
        self.label_alphabet_size = 0
        self.seqlabel_alphabet_size = 0

        self.max_sentence_length = 500

        self.term_truples = []

        self.sent_texts = []
        self.chars = []
        self.lengths = []
        self.ptags = []
        self.seq_labels = []

        self.word_ids_sent = []
        self.char_id_sent = []
        self.tag_ids_sent = []
        self.label_ids_sent = []
        self.seq_labels_ids = []

        self.longSpan = True
        self.shortSpan = True
        self.termratio = args.term_ratio
        self.term_span = args.max_length

        self.word_feature_extractor = "LSTM"  ## "LSTM"/"CNN"/"GRU"/
        self.char_feature_extractor = "CNN"  ## "LSTM"/"CNN"/"GRU"/None

        # training
        self.optimizer = 'Adam'  # "SGD"/"AdaGrad"/"AdaDelta"/"RMSProp"/"Adam"
        self.training = True
        self.average_batch_loss = True
        self.evaluate_every = args.evaluate_every  # 10 # evaluate every n batches
        self.print_every = args.print_every
        self.silence = True
        self.earlystop = args.early_stop

        # Embeddings
        self.word_emb_dir = args.wordemb  # './data/glove.6B.100d.txt' # None #'../data/glove.6b.100d.txt'
        self.char_emb_dir = args.charemb
        self.word_emb_dim = 50
        self.char_emb_dim = 30
        self.spamEm_dim = 30
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None

        # HP
        self.HP_char_hidden_dim = 50
        self.HP_hidden_dim = 100
        self.HP_cnn_layer = 2
        self.HP_batch_size = 100
        self.HP_epoch = 100
        self.HP_lr = args.lr
        self.HP_lr_decay = 0.05
        self.HP_clip = None
        self.HP_l2 = 1e-8
        self.HP_dropout = args.dropout
        self.HP_lstm_layer = 2
        self.HP_bilstm = True
        self.HP_gpu = args.use_gpu  # False#True
        self.HP_term_span = 6
        self.HP_momentum = 0

        # init data
        self.build_vocabs()
        self.all_instances = self.load_data()
        self.load_pretrain_emb()

    def build_vocabs(self):
        '''Build word/char/tag/label alphabets and gold term spans from the JSON data file.'''
        with open(self.data_dir, 'r') as filin:
            filelines = filin.readlines()

        for lin_id, lin_cnt in enumerate(filelines):
            lin_cnt = lin_cnt.strip()
            line = json.loads(lin_cnt)
            words = line['words']
            tags = line['tags']
            terms = line['terms']
            for word in words:
                self.word_alphabet.add(word)
                for char in word:
                    self.char_alphabet.add(char)
            for tag in tags:
                self.ptag_alphabet.add(tag)
            self.sent_texts.append(words)
            self.ptags.append(tags)
            assert len(words) == len(tags)
            self.lengths.append(len(words))
            seq_label, termple = self.reformat_label(words, terms)
            self.seq_labels.append(seq_label)

            if len(terms) > 0:
                tmp_terms = []
                for itm in termple:
                    tmp_terms.append([itm[0], itm[1], itm[2]])
                    self.label_alphabet.add(itm[2])
                self.term_truples.append(tmp_terms)
            else:
                self.term_truples.append([[-1, -1, 'None']])
                self.label_alphabet.add('None')

        for ter in self.seq_labels:
            for ater in ter:
                for ate in ater:
                    self.seqlabel_alphabet.add(ate)

        self.word_alphabet_size = self.word_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.ptag_alphabet_size = self.ptag_alphabet.size()
        self.seqlabel_alphabet_size = self.seqlabel_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        self.close_alphabet()

    def load_pretrain_emb(self):
        '''Load pretrained word/char embeddings when an embedding path is given.'''
        if self.word_emb_dir:
            print('Loading pretrained Word Embedding from {}'.format(
                self.word_emb_dir))
            self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
                self.word_emb_dir, self.word_alphabet, self.word_emb_dim,
                self.norm_word_emb)
        if self.char_emb_dir:
            print('Loading pretrained Char Embedding from {}'.format(
                self.char_emb_dir))
            self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(
                self.char_emb_dir, self.char_alphabet, self.char_emb_dim,
                self.norm_char_emb)

    def load_data(self):
        '''Convert the collected texts into id-based instances for training.'''
        all_instances = []
        assert len(self.sent_texts) == len(self.term_truples) == len(
            self.ptags)
        for sent_text, ptag, term_truple, seqlabel in zip(
                self.sent_texts, self.ptags, self.term_truples,
                self.seq_labels):
            self.word_ids_sent.append(self.word_alphabet.get_indexs(sent_text))
            sent_char = []
            sent_char_ids = []
            for word in sent_text:
                char_list = list(word)
                sent_char.append(char_list)
                char_ids = self.char_alphabet.get_indexs(char_list)
                sent_char_ids.append(char_ids)
            seqLabel_ids = [
                self.seqlabel_alphabet.get_indexs(seqlab)
                for seqlab in seqlabel
            ]
            self.seq_labels_ids.append(seqLabel_ids)
            self.chars.append(sent_char)
            self.char_id_sent.append(sent_char_ids)
            self.tag_ids_sent.append(self.ptag_alphabet.get_indexs(ptag))
            term_truple = [[
                term[0], term[1],
                self.label_alphabet.get_index(term[2])
            ] for term in term_truple]
            self.label_ids_sent.append(term_truple)
            all_instances.append([
                self.word_ids_sent[-1], sent_char_ids, self.tag_ids_sent[-1],
                [term_truple, seqLabel_ids], sent_text
            ])
        return all_instances

    def reformat_label(self, words, terms):
        label = [[] for _ in range(len(words))]
        termtruple = []
        if len(terms) > 0:
            for term in terms:
                beg = term[0]
                end = term[1]
                lab_ = term[2]
                termtruple.append((beg, end, lab_))
                if beg == end:
                    label[beg].append('S')
                    continue
                label[beg].append('B')
                label[end].append('E')
                if end - beg > 1:
                    for itm in range(beg + 1, end):
                        label[itm].append('I')
        for slab in label:
            if slab == []:
                slab.append('O')
        label = [list(set(lab)) for lab in label]
        return label, termtruple
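
    # A worked example (made-up input) of the scheme reformat_label implements:
    # with words = ['a', 'b', 'c', 'd'] and terms = [(1, 2, 'TECH')] it returns
    # label = [['O'], ['B'], ['E'], ['O']] and termtruple = [(1, 2, 'TECH')];
    # a single-token term (1, 1, 'TECH') would yield ['S'] at position 1.
    # Overlapping terms can leave several letters on one position, which is why
    # each position holds a list (deduplicated via set above) and why the
    # alphabet builder earlier iterates three levels deep over seq_labels.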

    def restore(self, data_file):
        print('Loading data from %s' % data_file)
        with open(data_file, 'rb') as filin:
            obj_dict = pkl.load(filin)
            self.__dict__.update(obj_dict)

    def save(self, save_file):
        print('Saving data to %s' % save_file)
        with open(save_file, 'wb') as filout:
            pkl.dump(self.__dict__, filout, 2)

    def close_alphabet(self):
        self.word_alphabet.close()
        self.ptag_alphabet.close()
        self.label_alphabet.close()
        self.seqlabel_alphabet.close()
        self.char_alphabet.close()
        return
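
# A minimal, hypothetical sketch of driving the class above end to end; its
# constructor is not shown in this example, so the attribute wiring here is an
# assumption based on the methods:
#
#     data = Data()
#     data.data_dir = 'train.jsonl'   # one {'words', 'tags', 'terms'} object per line
#     data.build_alphabet()           # assumed name of the reader method at the top
#     data.load_pretrain_emb()        # optional; needs word_emb_dir / char_emb_dir
#     instances = data.load_data()    # [[word_ids, char_ids, tag_ids,
#                                     #   [term_truple, seq_label_ids], words], ...]
#     data.save('data.dset')          # pickle the full state for a later restore()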
Beispiel #5
0
class Data(object):
    def __init__(self, data_config_file, alphabet_path, if_train=True):
        if if_train:
            with open(data_config_file, 'r') as rf:
                self.data_config = yaml.load(rf, Loader=yaml.FullLoader)
            # init data file
            mode = self.data_config['mode']
            self.data_file = os.path.join(ROOT_PATH,
                                          self.data_config['data'][mode])
            # init AC (Aho-Corasick) lexicon trees
            specific_words_file = os.path.join(
                ROOT_PATH, self.data_config['specific_words_file'])
            self.trees = Trees.build_trees(specific_words_file)
            # init alphabet
            self.char_alphabet = Alphabet('char')
            self.intent_alphabet = Alphabet('intent')
            self.label_alphabet = Alphabet('label', label=True)
            self.char_alphabet_size, self.intent_alphabet_size, self.label_alphabet_size = -1, -1, -1
            # pad length
            self.char_max_length = self.data_config['char_max_length']
            # read data file
            with open(self.data_file, 'r') as rf:
                self.corpus = rf.readlines()
            self.build_alphabet(alphabet_path)
            self.texts, self.ids = self.read_instance()
            self.train_texts, self.train_ids, self.dev_texts, self.dev_ids, self.test_texts, self.test_ids = self.sample_split(
            )
        else:  # inference use
            self.char_alphabet = Alphabet('char', keep_growing=False)
            self.intent_alphabet = Alphabet('intent', keep_growing=False)
            self.label_alphabet = Alphabet('label',
                                           label=True,
                                           keep_growing=False)

    def build_alphabet(self, alphabet_path):
        for line in self.corpus:
            line = ast.literal_eval(line)
            char, char_label, seg_list, intent = line['char'], line[
                'char_label'], line['word'], line['intent']
            for word in seg_list:
                # lexicon
                lexi_feat = []
                for lexi_type, lb in self.trees.lexi_trees.items():
                    lexi_feat.append(lb.search(word))
                for n in range(len(lexi_feat)):
                    if lexi_feat[n] is None or lexi_feat[n] == '_STEM_':
                        lexi_feat[n] = 0
                    else:
                        lexi_feat[n] = 1
                lexi_feat = ''.join([str(i) for i in lexi_feat])
                # abstract the lexicon feature string into a single pseudo-character
                self.char_alphabet.add(lexi_feat)
            # char
            for c in char:
                self.char_alphabet.add(normalize_word(c))
            # intent
            self.intent_alphabet.add(intent)
            # label
            for label in char_label:
                self.label_alphabet.add(label)
        # alphabet_size
        self.char_alphabet_size = self.char_alphabet.size()
        self.intent_alphabet_size = self.intent_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        # close alphabet
        self.fix_alphabet()

        # write alphabet:
        if not os.path.exists(alphabet_path):
            with open(alphabet_path, 'wb') as wbf:
                pickle.dump(self.char_alphabet.instance2index, wbf)
                pickle.dump(self.intent_alphabet.instance2index, wbf)
                pickle.dump(self.label_alphabet.instance2index, wbf)
                pickle.dump(self.label_alphabet.instances, wbf)
                pickle.dump(self.char_alphabet_size, wbf)
                pickle.dump(self.intent_alphabet_size, wbf)
                pickle.dump(self.label_alphabet_size, wbf)
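
        # For the inference path (keep_growing=False in __init__), the file
        # written here can be read back in the same order it was dumped;
        # pickle.load returns one object per call, mirroring the dumps above.
        # A sketch of that read-back (not shown in this example):
        #
        #     with open(alphabet_path, 'rb') as rbf:
        #         self.char_alphabet.instance2index = pickle.load(rbf)
        #         self.intent_alphabet.instance2index = pickle.load(rbf)
        #         self.label_alphabet.instance2index = pickle.load(rbf)
        #         self.label_alphabet.instances = pickle.load(rbf)
        #         self.char_alphabet_size = pickle.load(rbf)
        #         self.intent_alphabet_size = pickle.load(rbf)
        #         self.label_alphabet_size = pickle.load(rbf)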

    def read_instance(self):
        """
		这里读取完整读数据,不做截断,functions.py中做截断
		:return:
		"""
        texts, ids = [], []
        for idx, line in enumerate(self.corpus):
            line = ast.literal_eval(line)
            intent_id_list = []
            # word: '0010000' -> merged into a single label
            seq_char, seq_char_id_list, seq_label, seq_label_id_list = [], [], [], []
            char, char_label, seg_list, intent = line['char'], line[
                'char_label'], line['word'], line['intent']
            # lexicon features stored as one-hot strings
            lexicons = []
            # [start, end) character span of each word
            word_indices = []
            start = 0
            flag = True  # set to False to drop this sentence and move to the next one
            for word in seg_list:
                if flag is True:
                    end = start + len(word)
                    lexi_feat = []
                    for lexi_type, lb in self.trees.lexi_trees.items():
                        lexi_feat.append(lb.search(word))
                    for n in range(len(lexi_feat)):
                        if lexi_feat[n] is None or lexi_feat[n] == '_STEM_':
                            lexi_feat[n] = 0
                        else:
                            lexi_feat[n] = 1
                    lexi_feat = ''.join([str(i) for i in lexi_feat])
                    lexicons.append(lexi_feat)
                    word_indices.append([start, end])

                    # char
                    # '0010000'
                    if '1' in lexi_feat:
                        seq_char.append(lexi_feat)
                        seq_char_id_list.append(
                            self.char_alphabet.get_index(lexi_feat))
                        # ["B-room", "I-room", "I-room"]
                        specific_word_label = char_label[start:end]
                        tmp_label = [
                            swl.split('-')[-1] for swl in specific_word_label
                        ]
                        if len(set(tmp_label)) > 1:
                            # a lexicon word spanning more than one label type is filtered out
                            print('Filtered out: %s' % line['text'], word,
                                  tmp_label)
                            flag = False
                        else:
                            assert len(set(tmp_label)) == 1
                            if tmp_label[0] == 'O':
                                tmp_label = 'O'
                            else:
                                tmp_label = 'B' + '-' + tmp_label[0]
                            seq_label += [tmp_label]
                            seq_label_id_list += [
                                self.label_alphabet.get_index(tmp_label)
                            ]
                    # '0000000'
                    else:
                        for c in word:
                            seq_char.append(c)
                            seq_char_id_list.append(
                                self.char_alphabet.get_index(
                                    normalize_word(c)))
                        seq_label += char_label[start:end]
                        seq_label_id_list += [
                            self.label_alphabet.get_index(cl)
                            for cl in char_label[start:end]
                        ]

                    start = end
                else:
                    break  # skip to the next corpus line

            intent_id_list.append(self.intent_alphabet.get_index(intent))

            if idx % 10000 == 0:
                logger.info('read instance : %s' % idx)

            if flag is True:
                # text, char, intent, sequence_label
                texts.append([line['text'], seq_char, intent, seq_label])
                ids.append(
                    [seq_char_id_list, intent_id_list, seq_label_id_list])

        # save the corpus in its reprocessed form to make debugging easier
        output_path = self.data_config['data']['output']
        with open(output_path, 'w') as wf:
            for text in texts:
                line_data = dict()
                line_data['text'] = text[0]
                line_data['char'] = text[1]
                line_data['intent'] = text[2]
                line_data['char_label'] = text[-1]
                wf.write(json.dumps(line_data, ensure_ascii=False) + '\n')

        return texts, ids
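
    # How the lexicon feature behaves in read_instance (illustrative values):
    # with, say, seven lexicon trees, a word found only in the third tree gets
    # lexi_feat = '0010000'. Any '1' collapses the whole word into that single
    # pseudo-character plus one 'B-<type>' label; an all-zero string falls back
    # to per-character processing with the original char_label tags.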

    def fix_alphabet(self):
        self.char_alphabet.close()
        self.intent_alphabet.close()
        self.label_alphabet.close()

    # data sampling
    def sample_split(self):
        sampling_rate = self.data_config['sampling_rate']
        indexes = list(range(len(self.ids)))
        random.shuffle(indexes)
        shuffled_texts = [self.texts[i] for i in indexes]
        shuffled_ids = [self.ids[i] for i in indexes]
        logger.info('Top 10 shuffled indexes: %s' % indexes[:10])

        n = int(len(shuffled_ids) * sampling_rate)
        dev_texts, dev_ids = shuffled_texts[:n], shuffled_ids[:n]
        test_texts, test_ids = shuffled_texts[n:2 * n], shuffled_ids[n:2 * n]
        train_texts, train_ids = shuffled_texts[2 * n:], shuffled_ids[2 * n:]

        return train_texts, train_ids, dev_texts, dev_ids, test_texts, test_ids
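
# A standalone sketch of the arithmetic in sample_split (plain Python, no
# dependency on the Data class or its yaml config): sampling_rate carves one
# slice off for dev and one for test, and whatever remains is train.
import random

samples = list(range(1000))              # stand-in for the real instance ids
sampling_rate = 0.1                      # hypothetical value from the config
random.shuffle(samples)
n = int(len(samples) * sampling_rate)    # 100
dev, test, train = samples[:n], samples[n:2 * n], samples[2 * n:]
assert (len(dev), len(test), len(train)) == (100, 100, 800)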