Code Example #1
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('character')

        self.feature_name = []
        self.feature_alphabets = []
        self.feature_num = len(self.feature_alphabets)
        self.feat_config = None


        self.label_alphabet = Alphabet('label',True)
        self.tagScheme = "NoSeg" ## BMES/BIO
        
        self.seg = True

        ### I/O
        self.train_dir = None 
        self.dev_dir = None 
        self.test_dir = None 
        self.raw_dir = None

        self.decode_dir = None
        self.dset_dir = None ## data vocabulary related file
        self.model_dir = None ## model save  file
        self.load_model_dir = None ## model load file

        self.word_emb_dir = None 
        self.char_emb_dir = None
        self.feature_emb_dirs = []

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.pretrain_feature_embeddings = []

        self.label_size = 0
        self.word_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0
        self.feature_alphabet_sizes = []
        self.feature_emb_dims = []
        self.norm_feature_embs = []
        self.word_emb_dim = 50
        self.char_emb_dim = 30

        ###Networks
        self.word_feature_extractor = "LSTM" ## "LSTM"/"CNN"/"GRU"/
        self.use_char = True
        self.char_feature_extractor = "CNN" ## "LSTM"/"CNN"/"GRU"/None
        self.use_crf = True
        self.nbest = None
        
        ## Training
        self.average_batch_loss = False
        self.optimizer = "SGD" ## "SGD"/"AdaGrad"/"AdaDelta"/"RMSProp"/"Adam"
        self.status = "train"
        ### Hyperparameters
        self.HP_cnn_layer = 4
        self.HP_iteration = 100
        self.HP_batch_size = 10
        self.HP_char_hidden_dim = 50
        self.HP_hidden_dim = 200
        self.HP_dropout = 0.5
        self.HP_lstm_layer = 1
        self.HP_bilstm = True
        
        self.HP_gpu = False
        self.HP_lr = 0.015
        self.HP_lr_decay = 0.05
        self.HP_clip = None
        self.HP_momentum = 0
        self.HP_l2 = 1e-8
        
    def show_data_summary(self):
        print("++"*50)
        print("DATA SUMMARY START:")
        print(" I/O:")
        print("     Tag          scheme: %s"%(self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s"%(self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s"%(self.number_normalized))
        print("     Word  alphabet size: %s"%(self.word_alphabet_size))
        print("     Char  alphabet size: %s"%(self.char_alphabet_size))
        print("     Label alphabet size: %s"%(self.label_alphabet_size))
        print("     Word embedding  dir: %s"%(self.word_emb_dir))
        print("     Char embedding  dir: %s"%(self.char_emb_dir))
        print("     Word embedding size: %s"%(self.word_emb_dim))
        print("     Char embedding size: %s"%(self.char_emb_dim))
        print("     Norm   word     emb: %s"%(self.norm_word_emb))
        print("     Norm   char     emb: %s"%(self.norm_char_emb))
        print("     Train  file directory: %s"%(self.train_dir))
        print("     Dev    file directory: %s"%(self.dev_dir))
        print("     Test   file directory: %s"%(self.test_dir))
        print("     Raw    file directory: %s"%(self.raw_dir))
        print("     Dset   file directory: %s"%(self.dset_dir))
        print("     Model  file directory: %s"%(self.model_dir))
        print("     Loadmodel   directory: %s"%(self.load_model_dir))
        print("     Decode file directory: %s"%(self.decode_dir))
        print("     Train instance number: %s"%(len(self.train_texts)))
        print("     Dev   instance number: %s"%(len(self.dev_texts)))
        print("     Test  instance number: %s"%(len(self.test_texts)))
        print("     Raw   instance number: %s"%(len(self.raw_texts)))
        print("     FEATURE num: %s"%(self.feature_num))
        for idx in range(self.feature_num):
            print("         Fe: %s  alphabet  size: %s"%(self.feature_alphabets[idx].name, self.feature_alphabet_sizes[idx]))
            print("         Fe: %s  embedding  dir: %s"%(self.feature_alphabets[idx].name, self.feature_emb_dirs[idx]))
            print("         Fe: %s  embedding size: %s"%(self.feature_alphabets[idx].name, self.feature_emb_dims[idx]))
            print("         Fe: %s  norm       emb: %s"%(self.feature_alphabets[idx].name, self.norm_feature_embs[idx]))
        print(" "+"++"*20)
        print(" Model Network:")
        print("     Model        use_crf: %s"%(self.use_crf))
        print("     Model word extractor: %s"%(self.word_feature_extractor))
        print("     Model       use_char: %s"%(self.use_char))
        if self.use_char:
            print("     Model char extractor: %s"%(self.char_feature_extractor))
            print("     Model char_hidden_dim: %s"%(self.HP_char_hidden_dim))
        print(" "+"++"*20)
        print(" Training:")
        print("     Optimizer: %s"%(self.optimizer))
        print("     Iteration: %s"%(self.HP_iteration))
        print("     BatchSize: %s"%(self.HP_batch_size))
        print("     Average  batch   loss: %s"%(self.average_batch_loss))

        print(" "+"++"*20)
        print(" Hyperparameters:")
        
        print("     Hyper              lr: %s"%(self.HP_lr))
        print("     Hyper        lr_decay: %s"%(self.HP_lr_decay))
        print("     Hyper         HP_clip: %s"%(self.HP_clip))
        print("     Hyper        momentum: %s"%(self.HP_momentum))
        print("     Hyper              l2: %s"%(self.HP_l2))
        print("     Hyper      hidden_dim: %s"%(self.HP_hidden_dim))
        print("     Hyper         dropout: %s"%(self.HP_dropout))
        print("     Hyper      lstm_layer: %s"%(self.HP_lstm_layer))
        print("     Hyper          bilstm: %s"%(self.HP_bilstm))
        print("     Hyper             GPU: %s"%(self.HP_gpu))   
        print("DATA SUMMARY END.")
        print("++"*50)
        sys.stdout.flush()


    def initial_feature_alphabets(self):
        items = open(self.train_dir,'r').readline().strip('\n').split()
        total_column = len(items)
        if total_column > 2:
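            ## middle columns carry features written as "[FeatureName]value";
            ## keep the "[FeatureName]" prefix as the alphabet name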
            for idx in range(1, total_column-1):
                feature_prefix = items[idx].split(']',1)[0]+"]"
                self.feature_alphabets.append(Alphabet(feature_prefix))
                self.feature_name.append(feature_prefix)
                print "Find feature: ", feature_prefix 
        self.feature_num = len(self.feature_alphabets)
        self.pretrain_feature_embeddings = [None]*self.feature_num
        self.feature_emb_dims = [20]*self.feature_num
        self.feature_emb_dirs = [None]*self.feature_num 
        self.norm_feature_embs = [False]*self.feature_num
        self.feature_alphabet_sizes = [0]*self.feature_num
        if self.feat_config:
            for idx in range(self.feature_num):
                if self.feature_name[idx] in self.feat_config:
                    self.feature_emb_dims[idx] = self.feat_config[self.feature_name[idx]]['emb_size']
                    self.feature_emb_dirs[idx] = self.feat_config[self.feature_name[idx]]['emb_dir']
                    self.norm_feature_embs[idx] = self.feat_config[self.feature_name[idx]]['emb_norm']
        # exit(0)


    def build_alphabet(self, input_file):
        in_lines = open(input_file,'r').readlines()
        for line in in_lines:
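            ## data lines look like "word [feat1]... label"; blank lines separate sentences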
            if len(line) > 2:
                pairs = line.strip().split()
                word = pairs[0].decode('utf-8')
                if self.number_normalized:
                    word = normalize_word(word)
                label = pairs[-1]
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)
                ## build feature alphabet 
                for idx in range(self.feature_num):
                    feat_idx = pairs[idx+1].split(']',1)[-1]
                    self.feature_alphabets[idx].add(feat_idx)
                for char in word:
                    self.char_alphabet.add(char)
        self.word_alphabet_size = self.word_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        for idx in range(self.feature_num):
            self.feature_alphabet_sizes[idx] = self.feature_alphabets[idx].size()
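        ## infer the tagging scheme from the labels seen: any "S-" label
        ## implies BMES, otherwise any "B-" label implies BIO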
        startS = False
        startB = False
        for label,_ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"


    def fix_alphabet(self):
        self.word_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close() 
        for idx in range(self.feature_num):
            self.feature_alphabets[idx].close()      


    def build_pretrain_emb(self):
        if self.word_emb_dir:
            print("Load pretrained word embedding, norm: %s, dir: %s"%(self.norm_word_emb, self.word_emb_dir))
            self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(self.word_emb_dir, self.word_alphabet, self.word_emb_dim, self.norm_word_emb)
        if self.char_emb_dir:
            print("Load pretrained char embedding, norm: %s, dir: %s"%(self.norm_char_emb, self.char_emb_dir))
            self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(self.char_emb_dir, self.char_alphabet, self.char_emb_dim, self.norm_char_emb)
        for idx in range(self.feature_num):
            if self.feature_emb_dirs[idx]:
                print("Load pretrained feature %s embedding:, norm: %s, dir: %s"%(self.feature_name[idx], self.norm_feature_embs[idx], self.feature_emb_dirs[idx]))
                self.pretrain_feature_embeddings[idx], self.feature_emb_dims[idx] = build_pretrain_embedding(self.feature_emb_dirs[idx], self.feature_alphabets[idx], self.feature_emb_dims[idx], self.norm_feature_embs[idx])


    def generate_instance(self, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance(self.train_dir, self.word_alphabet, self.char_alphabet, self.feature_alphabets, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance(self.dev_dir, self.word_alphabet, self.char_alphabet, self.feature_alphabets, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance(self.test_dir, self.word_alphabet, self.char_alphabet, self.feature_alphabets, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance(self.raw_dir, self.word_alphabet, self.char_alphabet, self.feature_alphabets, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        else:
            print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name))


    def write_decoded_results(self, predict_results, name):
        fout = open(self.decode_dir,'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !")
        assert(sent_num == len(content_list))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                ## content_list[idx] is a list with [word, char, label]
                fout.write(content_list[idx][0][idy].encode('utf-8') + " " + predict_results[idx][idy] + '\n')
            fout.write('\n')
        fout.close()
        print("Predict %s result has been written into file. %s"%(name, self.decode_dir))


    def load(self,data_file):
        f = open(data_file, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def save(self,save_file):
        f = open(save_file, 'wb')
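        ## pickle protocol 2: a binary format writable from Python 2 and readable from Python 3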
        pickle.dump(self.__dict__, f, 2)
        f.close()



    def write_nbest_decoded_results(self, predict_results, pred_scores, name):
        ## predict_results : [whole_sent_num, nbest, each_sent_length]
        ## pred_scores: [whole_sent_num, nbest]
        fout = open(self.decode_dir,'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !")
        assert(sent_num == len(content_list))
        assert(sent_num == len(pred_scores))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx][0])
            nbest = len(predict_results[idx])
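            ## each sentence is preceded by a "# score1 score2 ..." line, one score per n-best path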
            score_string = "# "
            for idz in range(nbest):
                score_string += format(pred_scores[idx][idz], '.4f')+" "
            fout.write(score_string.strip() + "\n")

            for idy in range(sent_length):
                label_string = content_list[idx][0][idy].encode('utf-8') + " "
                for idz in range(nbest):
                    label_string += predict_results[idx][idz][idy]+" "
                label_string = label_string.strip() + "\n"
                fout.write(label_string)
            fout.write('\n')
        fout.close()
        print("Predict %s %s-best result has been written into file. %s"%(name,nbest, self.decode_dir))


    def read_config(self,config_file):
        config = config_file_to_dict(config_file)
        ## read data:
        the_item = 'train_dir'
        if the_item in config:
            self.train_dir = config[the_item]
        the_item = 'dev_dir'
        if the_item in config:
            self.dev_dir = config[the_item]
        the_item = 'test_dir'
        if the_item in config:
            self.test_dir = config[the_item]
        the_item = 'raw_dir'
        if the_item in config:
            self.raw_dir = config[the_item]
        the_item = 'decode_dir'
        if the_item in config:
            self.decode_dir = config[the_item]
        the_item = 'dset_dir'
        if the_item in config:
            self.dset_dir = config[the_item]
        the_item = 'model_dir'
        if the_item in config:
            self.model_dir = config[the_item]
        the_item = 'load_model_dir'
        if the_item in config:
            self.load_model_dir = config[the_item]

        the_item = 'word_emb_dir'
        if the_item in config:
            self.word_emb_dir = config[the_item]
        the_item = 'char_emb_dir'
        if the_item in config:
            self.char_emb_dir = config[the_item]


        the_item = 'MAX_SENTENCE_LENGTH'
        if the_item in config:
            self.MAX_SENTENCE_LENGTH = int(config[the_item])
        the_item = 'MAX_WORD_LENGTH'
        if the_item in config:
            self.MAX_WORD_LENGTH = int(config[the_item])

        the_item = 'norm_word_emb'
        if the_item in config:
            self.norm_word_emb = str2bool(config[the_item])
        the_item = 'norm_char_emb'
        if the_item in config:
            self.norm_char_emb = str2bool(config[the_item])
        the_item = 'number_normalized'
        if the_item in config:
            self.number_normalized = str2bool(config[the_item])


        the_item = 'seg'
        if the_item in config:
            self.seg = str2bool(config[the_item])
        the_item = 'word_emb_dim'
        if the_item in config:
            self.word_emb_dim = int(config[the_item])
        the_item = 'char_emb_dim'
        if the_item in config:
            self.char_emb_dim = int(config[the_item])

        ## read network:
        the_item = 'use_crf'
        if the_item in config:
            self.use_crf = str2bool(config[the_item])
        the_item = 'use_char'
        if the_item in config:
            self.use_char = str2bool(config[the_item])
        the_item = 'word_seq_feature'
        if the_item in config:
            self.word_feature_extractor = config[the_item]
        the_item = 'char_seq_feature'
        if the_item in config:
            self.char_feature_extractor = config[the_item]
        the_item = 'nbest'
        if the_item in config:
            self.nbest = int(config[the_item])

        the_item = 'feature'
        if the_item in config:
            self.feat_config = config[the_item] ## feat_config is a dict 






        ## read training setting:
        the_item = 'optimizer'
        if the_item in config:
            self.optimizer = config[the_item]
        the_item = 'ave_batch_loss'
        if the_item in config:
            self.average_batch_loss = str2bool(config[the_item])
        the_item = 'status'
        if the_item in config:
            self.status = config[the_item]

        ## read Hyperparameters:
        the_item = 'cnn_layer'
        if the_item in config:
            self.HP_cnn_layer = int(config[the_item])
        the_item = 'iteration'
        if the_item in config:
            self.HP_iteration = int(config[the_item])
        the_item = 'batch_size'
        if the_item in config:
            self.HP_batch_size = int(config[the_item])

        the_item = 'char_hidden_dim'
        if the_item in config:
            self.HP_char_hidden_dim = int(config[the_item])
        the_item = 'hidden_dim'
        if the_item in config:
            self.HP_hidden_dim = int(config[the_item])
        the_item = 'dropout'
        if the_item in config:
            self.HP_dropout = float(config[the_item])
        the_item = 'lstm_layer'
        if the_item in config:
            self.HP_lstm_layer = int(config[the_item])
        the_item = 'bilstm'
        if the_item in config:
            self.HP_bilstm = str2bool(config[the_item])

        the_item = 'gpu'
        if the_item in config:
            self.HP_gpu = str2bool(config[the_item])
        the_item = 'learning_rate'
        if the_item in config:
            self.HP_lr = float(config[the_item])
        the_item = 'lr_decay'
        if the_item in config:
            self.HP_lr_decay = float(config[the_item])
        the_item = 'clip'
        if the_item in config:
            self.HP_clip = float(config[the_item])
        the_item = 'momentum'
        if the_item in config:
            self.HP_momentum = float(config[the_item])
        the_item = 'l2'
        if the_item in config:
            self.HP_l2 = float(config[the_item])
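
A typical way to drive this class mirrors the NCRF++ pipeline: read the config, grow the alphabets from the data files, freeze them, load any pretrained embeddings, then convert each split into id sequences. The sketch below is illustrative only; the config path is a hypothetical placeholder, and helpers such as config_file_to_dict, read_instance and build_pretrain_embedding are assumed to be importable from the surrounding project.

# Minimal driver sketch (assumption: 'demo.train.config' is a hypothetical
# key=value config file naming train_dir/dev_dir/test_dir and friends).
data = Data()
data.read_config('demo.train.config')
data.initial_feature_alphabets()        # inspect the header line of train_dir
data.build_alphabet(data.train_dir)     # grow word/char/label/feature alphabets
data.build_alphabet(data.dev_dir)
data.build_alphabet(data.test_dir)
data.fix_alphabet()                     # freeze all alphabets
data.build_pretrain_emb()               # only loads the *_emb_dir paths that are set
data.generate_instance("train")
data.generate_instance("dev")
data.generate_instance("test")
data.show_data_summary()
data.save("demo.dset")                  # pickle the whole Data state
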
Code Example #2
File: data.py  Project: foxlf823/e2e_ner_re
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('character')

        self.feature_name = []
        self.feature_alphabets = []
        self.feature_num = len(self.feature_alphabets)
        self.feat_config = None
        self.feature_name2id = {}


        self.label_alphabet = Alphabet('label',True)
        self.tagScheme = "NoSeg" ## BMES/BIO
        
        self.seg = True

        ### I/O
        self.train_dir = None 
        self.dev_dir = None 
        self.test_dir = None


        self.model_dir = None ## model save  file


        self.word_emb_dir = None 
        self.char_emb_dir = None
        self.feature_emb_dirs = []

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []


        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []


        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.pretrain_feature_embeddings = []

        self.label_size = 0
        self.word_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0
        self.feature_alphabet_sizes = []
        self.feature_emb_dims = []
        self.norm_feature_embs = []
        self.word_emb_dim = 50
        self.char_emb_dim = 30

        ###Networks
        self.word_feature_extractor = "LSTM" ## "LSTM"/"CNN"/"GRU"/
        self.use_char = True
        self.char_feature_extractor = "CNN" ## "LSTM"/"CNN"/"GRU"/None
        self.use_crf = True
        self.nbest = None
        
        ## Training
        self.average_batch_loss = False
        self.optimizer = "SGD" ## "SGD"/"AdaGrad"/"AdaDelta"/"RMSProp"/"Adam"

        ### Hyperparameters
        self.HP_cnn_layer = 4
        self.HP_iteration = 100
        self.HP_batch_size = 10
        self.HP_char_hidden_dim = 50
        self.HP_hidden_dim = 200
        self.HP_dropout = 0.5
        self.HP_lstm_layer = 1
        self.HP_bilstm = True
        
        self.HP_gpu = False
        self.HP_lr = 0.015
        self.HP_lr_decay = 0.05
        self.HP_clip = None
        self.HP_momentum = 0
        self.HP_l2 = 1e-8

        # both
        self.full_data = False
        self.tune_wordemb = False

        # relation
        self.pretrain = None
        self.max_seq_len = 500
        self.pad_idx = 1
        self.sent_window = 3
        self.output = None
        self.unk_ratio = 1
        self.seq_feature_size = 256
        self.max_epoch = 100
        self.feature_extractor = None

        self.re_feature_name = []
        self.re_feature_name2id = {}
        self.re_feature_alphabets = []
        self.re_feature_num = len(self.re_feature_alphabets)
        self.re_feat_config = None

        self.re_train_X = []
        self.re_dev_X = []
        self.re_test_X = []
        self.re_train_Y = []
        self.re_dev_Y = []
        self.re_test_Y = []

        
    def show_data_summary(self):
        print("++"*50)
        print("DATA SUMMARY START:")
        print(" I/O:")
        print("     Tag          scheme: %s"%(self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s"%(self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s"%(self.number_normalized))
        print("     Word  alphabet size: %s"%(self.word_alphabet_size))
        print("     Char  alphabet size: %s"%(self.char_alphabet_size))
        print("     Label alphabet size: %s"%(self.label_alphabet_size))
        print("     Word embedding  dir: %s"%(self.word_emb_dir))
        print("     Char embedding  dir: %s"%(self.char_emb_dir))
        print("     Word embedding size: %s"%(self.word_emb_dim))
        print("     Char embedding size: %s"%(self.char_emb_dim))
        print("     Norm   word     emb: %s"%(self.norm_word_emb))
        print("     Norm   char     emb: %s"%(self.norm_char_emb))
        print("     Train  file directory: %s"%(self.train_dir))
        print("     Dev    file directory: %s"%(self.dev_dir))
        print("     Test   file directory: %s"%(self.test_dir))


        print("     Model  file directory: %s"%(self.model_dir))


        print("     Train instance number: %s"%(len(self.train_texts)))
        print("     Dev   instance number: %s"%(len(self.dev_texts)))
        print("     Test  instance number: %s"%(len(self.test_texts)))

        print("     FEATURE num: %s"%(self.feature_num))
        for idx in range(self.feature_num):
            print("         Fe: %s  alphabet  size: %s"%(self.feature_alphabets[idx].name, self.feature_alphabet_sizes[idx]))
            print("         Fe: %s  embedding  dir: %s"%(self.feature_alphabets[idx].name, self.feature_emb_dirs[idx]))
            print("         Fe: %s  embedding size: %s"%(self.feature_alphabets[idx].name, self.feature_emb_dims[idx]))
            print("         Fe: %s  norm       emb: %s"%(self.feature_alphabets[idx].name, self.norm_feature_embs[idx]))
        # for k, v in self.feat_config.items():
        #     print("         Feature: %s, size %s, norm %s, dir %s"%(k, v['emb_size'], v['emb_norm'], v['emb_dir']))

        print(" "+"++"*20)
        print(" Model Network:")
        print("     Model        use_crf: %s"%(self.use_crf))
        print("     Model word extractor: %s"%(self.word_feature_extractor))
        print("     Model       use_char: %s"%(self.use_char))
        if self.use_char:
            print("     Model char extractor: %s"%(self.char_feature_extractor))
            print("     Model char_hidden_dim: %s"%(self.HP_char_hidden_dim))
        print(" "+"++"*20)
        print(" Training:")
        print("     Optimizer: %s"%(self.optimizer))
        print("     Iteration: %s"%(self.HP_iteration))
        print("     BatchSize: %s"%(self.HP_batch_size))
        print("     Average  batch   loss: %s"%(self.average_batch_loss))

        print(" "+"++"*20)
        print(" Hyperparameters:")
        
        print("     Hyper              lr: %s"%(self.HP_lr))
        print("     Hyper        lr_decay: %s"%(self.HP_lr_decay))
        print("     Hyper         HP_clip: %s"%(self.HP_clip))
        print("     Hyper        momentum: %s"%(self.HP_momentum))
        print("     Hyper              l2: %s"%(self.HP_l2))
        print("     Hyper      hidden_dim: %s"%(self.HP_hidden_dim))
        print("     Hyper         dropout: %s"%(self.HP_dropout))
        print("     Hyper      lstm_layer: %s"%(self.HP_lstm_layer))
        print("     Hyper          bilstm: %s"%(self.HP_bilstm))
        print("     Hyper             GPU: %s"%(self.HP_gpu))
        print("     Hyper             NBEST: %s"%(self.nbest))

        print(" " + "++" * 20)
        print(" Both:")

        print("     full data: %s" % (self.full_data))
        print("     Tune  word embeddings: %s" % (self.tune_wordemb))

        print(" "+"++"*20)
        print(" Relation:")

        print("     Pretrain directory: %s" % (self.pretrain))
        print("     max sequence length: %s" % (self.max_seq_len))
        print("     pad index: %s" % (self.pad_idx))
        print("     sentence window: %s" % (self.sent_window))
        print("     Output directory: %s" % (self.output))
        print("     The ratio using negative instnaces 0~1: %s" % (self.unk_ratio))
        print("     Size of seqeuence feature representation: %s" % (self.seq_feature_size))
        print("     Iteration for relation training: %s" % (self.max_epoch))
        print("     feature_extractor: %s" % (self.feature_extractor))

        print("     RE FEATURE num: %s"%(self.re_feature_num))
        for idx in range(self.re_feature_num):
            print("         Fe: %s  alphabet  size: %s"%(self.re_feature_alphabets[idx].name, self.re_feature_alphabet_sizes[idx]))
            print("         Fe: %s  embedding  dir: %s"%(self.re_feature_alphabets[idx].name, self.re_feature_emb_dirs[idx]))
            print("         Fe: %s  embedding size: %s"%(self.re_feature_alphabets[idx].name, self.re_feature_emb_dims[idx]))
            print("         Fe: %s  norm       emb: %s"%(self.re_feature_alphabets[idx].name, self.re_norm_feature_embs[idx]))

        print("     RE Train instance number: %s"%(len(self.re_train_Y)))
        print("     RE Dev   instance number: %s"%(len(self.re_dev_Y)))
        print("     RE Test  instance number: %s"%(len(self.re_test_Y)))

        print("DATA SUMMARY END.")
        print("++"*50)
        sys.stdout.flush()


    def initial_feature_alphabets(self, input_file):
        items = open(input_file,'r').readline().strip('\n').split()
        total_column = len(items)
        if total_column > 2:
            id = 0
            for idx in range(1, total_column-1):
                feature_prefix = items[idx].split(']',1)[0]+"]"
                self.feature_alphabets.append(Alphabet(feature_prefix))
                self.feature_name.append(feature_prefix)
                self.feature_name2id[feature_prefix] = id
                id += 1
                print "Find feature: ", feature_prefix 
        self.feature_num = len(self.feature_alphabets)
        self.pretrain_feature_embeddings = [None]*self.feature_num
        self.feature_emb_dims = [20]*self.feature_num
        self.feature_emb_dirs = [None]*self.feature_num 
        self.norm_feature_embs = [False]*self.feature_num
        self.feature_alphabet_sizes = [0]*self.feature_num
        if self.feat_config:
            for idx in range(self.feature_num):
                if self.feature_name[idx] in self.feat_config:
                    self.feature_emb_dims[idx] = self.feat_config[self.feature_name[idx]]['emb_size']
                    self.feature_emb_dirs[idx] = self.feat_config[self.feature_name[idx]]['emb_dir']
                    self.norm_feature_embs[idx] = self.feat_config[self.feature_name[idx]]['emb_norm']
        # exit(0)


    def build_alphabet(self, input_file):
        in_lines = open(input_file,'r').readlines()
        for line in in_lines:
            if len(line) > 2:
                pairs = line.strip().split()
                word = pairs[0].decode('utf-8')
                if self.number_normalized:
                    word = normalize_word(word)
                label = pairs[-1]
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)
                ## build feature alphabet 
                for idx in range(self.feature_num):
                    feat_idx = pairs[idx+1].split(']',1)[-1]
                    self.feature_alphabets[idx].add(feat_idx)
                for char in word:
                    self.char_alphabet.add(char)
        self.word_alphabet_size = self.word_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        for idx in range(self.feature_num):
            self.feature_alphabet_sizes[idx] = self.feature_alphabets[idx].size()
        startS = False
        startB = False
        for label,_ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"


    def fix_alphabet(self):
        self.word_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close() 
        for idx in range(self.feature_num):
            self.feature_alphabets[idx].close()

    def initial_re_feature_alphabets(self):
        id = 0
        for k, v in self.re_feat_config.items():
            self.re_feature_alphabets.append(Alphabet(k))
            self.re_feature_name.append(k)
            self.re_feature_name2id[k] = id
            id += 1

        self.re_feature_num = len(self.re_feature_alphabets)
        self.re_pretrain_feature_embeddings = [None]*self.re_feature_num
        self.re_feature_emb_dims = [20]*self.re_feature_num
        self.re_feature_emb_dirs = [None]*self.re_feature_num
        self.re_norm_feature_embs = [False]*self.re_feature_num
        self.re_feature_alphabet_sizes = [0]*self.re_feature_num
        if self.re_feat_config:
            for idx in range(self.re_feature_num):
                if self.re_feature_name[idx] in self.re_feat_config:
                    self.re_feature_emb_dims[idx] = self.re_feat_config[self.re_feature_name[idx]]['emb_size']
                    self.re_feature_emb_dirs[idx] = self.re_feat_config[self.re_feature_name[idx]]['emb_dir']
                    self.re_norm_feature_embs[idx] = self.re_feat_config[self.re_feature_name[idx]]['emb_norm']


    def build_re_feature_alphabets(self, tokens, entities, relations):

        entity_type_alphabet = self.re_feature_alphabets[self.re_feature_name2id['[ENTITY_TYPE]']]
        entity_alphabet = self.re_feature_alphabets[self.re_feature_name2id['[ENTITY]']]
        relation_alphabet = self.re_feature_alphabets[self.re_feature_name2id['[RELATION]']]
        token_num_alphabet = self.re_feature_alphabets[self.re_feature_name2id['[TOKEN_NUM]']]
        entity_num_alphabet = self.re_feature_alphabets[self.re_feature_name2id['[ENTITY_NUM]']]
        position_alphabet = self.re_feature_alphabets[self.re_feature_name2id['[POSITION]']]

        for i, doc_token in enumerate(tokens):

            doc_entity = entities[i]
            doc_relation = relations[i]

            sent_idx = 0
            sentence = doc_token[(doc_token['sent_idx'] == sent_idx)]
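            ## iterate sentence by sentence until no rows match the next sent_idx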
            while sentence.shape[0] != 0:

                entities_in_sentence = doc_entity[(doc_entity['sent_idx'] == sent_idx)]
                for _, entity in entities_in_sentence.iterrows():
                    entity_type_alphabet.add(entity['type'])
                    tk_idx = entity['tf_start']
                    while tk_idx <= entity['tf_end']:
                        entity_alphabet.add(
                            my_utils1.normalizeWord(sentence.iloc[tk_idx, 0]))  # assume 'text' is in 0 column
                        tk_idx += 1

                sent_idx += 1
                sentence = doc_token[(doc_token['sent_idx'] == sent_idx)]

            for _, relation in doc_relation.iterrows():
                relation_alphabet.add(relation['type'])


        for i in range(self.max_seq_len):
            token_num_alphabet.add(i)
            entity_num_alphabet.add(i)
            position_alphabet.add(i)
            position_alphabet.add(-i)


        for idx in range(self.re_feature_num):
            self.re_feature_alphabet_sizes[idx] = self.re_feature_alphabets[idx].size()


    def fix_re_alphabet(self):
        for alphabet in self.re_feature_alphabets:
            alphabet.close()


    def build_pretrain_emb(self):
        if self.word_emb_dir:
            print("Load pretrained word embedding, norm: %s, dir: %s"%(self.norm_word_emb, self.word_emb_dir))
            self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(self.word_emb_dir, self.word_alphabet, self.word_emb_dim, self.norm_word_emb)
        if self.char_emb_dir:
            print("Load pretrained char embedding, norm: %s, dir: %s"%(self.norm_char_emb, self.char_emb_dir))
            self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(self.char_emb_dir, self.char_alphabet, self.char_emb_dim, self.norm_char_emb)
        for idx in range(self.feature_num):
            if self.feature_emb_dirs[idx]:
                print("Load pretrained feature %s embedding:, norm: %s, dir: %s"%(self.feature_name[idx], self.norm_feature_embs[idx], self.feature_emb_dirs[idx]))
                self.pretrain_feature_embeddings[idx], self.feature_emb_dims[idx] = build_pretrain_embedding(self.feature_emb_dirs[idx], self.feature_alphabets[idx], self.feature_emb_dims[idx], self.norm_feature_embs[idx])

    def build_re_pretrain_emb(self):
        for idx in range(self.re_feature_num):
            if self.re_feature_emb_dirs[idx]:
                print("Load pretrained re feature %s embedding:, norm: %s, dir: %s" % (self.re_feature_name[idx], self.re_norm_feature_embs[idx], self.re_feature_emb_dirs[idx]))
                self.re_pretrain_feature_embeddings[idx], self.re_feature_emb_dims[idx] = build_pretrain_embedding(
                    self.re_feature_emb_dirs[idx], self.re_feature_alphabets[idx], self.re_feature_emb_dims[idx],
                    self.re_norm_feature_embs[idx])

    def generate_instance(self, name, input_file):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance(input_file, self.word_alphabet, self.char_alphabet, self.feature_alphabets, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance(input_file, self.word_alphabet, self.char_alphabet, self.feature_alphabets, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance(input_file, self.word_alphabet, self.char_alphabet, self.feature_alphabets, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        else:
            print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name))



    def generate_re_instance(self, name, tokens, entities, relations, names):
        self.fix_re_alphabet()
        if name == "train":
            self.re_train_X, self.re_train_Y = relation_extraction.getRelationInstance2(tokens, entities, relations, names, self)
        elif name == "dev":
            self.re_dev_X, self.re_dev_Y = relation_extraction.getRelationInstance2(tokens, entities, relations, names, self)
        elif name == "test":
            self.re_test_X, self.re_test_Y = relation_extraction.getRelationInstance2(tokens, entities, relations, names, self)
        else:
            print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name))


    def load(self,data_file):
        f = open(data_file, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def save(self,save_file):
        f = open(save_file, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()




    def read_config(self,config_file):
        config = config_file_to_dict(config_file)
        ## read data:
        the_item = 'train_dir'
        if the_item in config:
            self.train_dir = config[the_item]
        the_item = 'dev_dir'
        if the_item in config:
            self.dev_dir = config[the_item]
        the_item = 'test_dir'
        if the_item in config:
            self.test_dir = config[the_item]


        the_item = 'model_dir'
        if the_item in config:
            self.model_dir = config[the_item]


        the_item = 'word_emb_dir'
        if the_item in config:
            self.word_emb_dir = config[the_item]
        the_item = 'char_emb_dir'
        if the_item in config:
            self.char_emb_dir = config[the_item]


        the_item = 'MAX_SENTENCE_LENGTH'
        if the_item in config:
            self.MAX_SENTENCE_LENGTH = int(config[the_item])
        the_item = 'MAX_WORD_LENGTH'
        if the_item in config:
            self.MAX_WORD_LENGTH = int(config[the_item])

        the_item = 'norm_word_emb'
        if the_item in config:
            self.norm_word_emb = str2bool(config[the_item])
        the_item = 'norm_char_emb'
        if the_item in config:
            self.norm_char_emb = str2bool(config[the_item])
        the_item = 'number_normalized'
        if the_item in config:
            self.number_normalized = str2bool(config[the_item])


        the_item = 'seg'
        if the_item in config:
            self.seg = str2bool(config[the_item])
        the_item = 'word_emb_dim'
        if the_item in config:
            self.word_emb_dim = int(config[the_item])
        the_item = 'char_emb_dim'
        if the_item in config:
            self.char_emb_dim = int(config[the_item])

        ## read network:
        the_item = 'use_crf'
        if the_item in config:
            self.use_crf = str2bool(config[the_item])
        the_item = 'use_char'
        if the_item in config:
            self.use_char = str2bool(config[the_item])
        the_item = 'word_seq_feature'
        if the_item in config:
            self.word_feature_extractor = config[the_item]
        the_item = 'char_seq_feature'
        if the_item in config:
            self.char_feature_extractor = config[the_item]
        the_item = 'nbest'
        if the_item in config:
            self.nbest = int(config[the_item])

        the_item = 'feature'
        if the_item in config:
            self.feat_config = config[the_item] ## feat_config is a dict 






        ## read training setting:
        the_item = 'optimizer'
        if the_item in config:
            self.optimizer = config[the_item]
        the_item = 'ave_batch_loss'
        if the_item in config:
            self.average_batch_loss = str2bool(config[the_item])


        ## read Hyperparameters:
        the_item = 'cnn_layer'
        if the_item in config:
            self.HP_cnn_layer = int(config[the_item])
        the_item = 'iteration'
        if the_item in config:
            self.HP_iteration = int(config[the_item])
        the_item = 'batch_size'
        if the_item in config:
            self.HP_batch_size = int(config[the_item])

        the_item = 'char_hidden_dim'
        if the_item in config:
            self.HP_char_hidden_dim = int(config[the_item])
        the_item = 'hidden_dim'
        if the_item in config:
            self.HP_hidden_dim = int(config[the_item])
        the_item = 'dropout'
        if the_item in config:
            self.HP_dropout = float(config[the_item])
        the_item = 'lstm_layer'
        if the_item in config:
            self.HP_lstm_layer = int(config[the_item])
        the_item = 'bilstm'
        if the_item in config:
            self.HP_bilstm = str2bool(config[the_item])

        the_item = 'gpu'
        if the_item in config:
            self.HP_gpu = int(config[the_item])
        the_item = 'learning_rate'
        if the_item in config:
            self.HP_lr = float(config[the_item])
        the_item = 'lr_decay'
        if the_item in config:
            self.HP_lr_decay = float(config[the_item])
        the_item = 'clip'
        if the_item in config:
            self.HP_clip = float(config[the_item])
        the_item = 'momentum'
        if the_item in config:
            self.HP_momentum = float(config[the_item])
        the_item = 'l2'
        if the_item in config:
            self.HP_l2 = float(config[the_item])

        # both
        the_item = 'full_data'
        if the_item in config:
            self.full_data = str2bool(config[the_item])

        the_item = 'tune_wordemb'
        if the_item in config:
            self.tune_wordemb = str2bool(config[the_item])

        # relation
        the_item = 'pretrain'
        if the_item in config:
            self.pretrain = config[the_item]

        the_item = 'max_seq_len'
        if the_item in config:
            self.max_seq_len = int(config[the_item])

        the_item = 'pad_idx'
        if the_item in config:
            self.pad_idx = int(config[the_item])

        the_item = 'sent_window'
        if the_item in config:
            self.sent_window = int(config[the_item])

        the_item = 'output'
        if the_item in config:
            self.output = config[the_item]

        the_item = 'unk_ratio'
        if the_item in config:
            self.unk_ratio = float(config[the_item])

        the_item = 'seq_feature_size'
        if the_item in config:
            self.seq_feature_size = int(config[the_item])

        the_item = 'max_epoch'
        if the_item in config:
            self.max_epoch = int(config[the_item])

        the_item = 'feature_extractor'
        if the_item in config:
            self.feature_extractor = config[the_item]

        the_item = 're_feature'
        if the_item in config:
            self.re_feat_config = config[the_item] ## feat_config is a dict
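
The relation-extraction variant keys its feature alphabets by name: build_re_feature_alphabets() looks up exactly the six bracketed names used below, so the 're_feature' config entry (stored in re_feat_config) must provide them. A minimal sketch of that dict shape follows; the emb_size values are illustrative assumptions, and the Alphabet helper is assumed to be importable.

# Sketch of the re_feat_config shape consumed by initial_re_feature_alphabets().
data = Data()
data.re_feat_config = {
    '[ENTITY_TYPE]': {'emb_size': 20, 'emb_dir': None, 'emb_norm': False},
    '[ENTITY]':      {'emb_size': 50, 'emb_dir': None, 'emb_norm': False},
    '[RELATION]':    {'emb_size': 20, 'emb_dir': None, 'emb_norm': False},
    '[TOKEN_NUM]':   {'emb_size': 10, 'emb_dir': None, 'emb_norm': False},
    '[ENTITY_NUM]':  {'emb_size': 10, 'emb_dir': None, 'emb_norm': False},
    '[POSITION]':    {'emb_size': 10, 'emb_dir': None, 'emb_norm': False},
}
data.initial_re_feature_alphabets()
print(data.re_feature_num)              # 6 alphabets registered
print(sorted(data.re_feature_name2id))  # ids follow dict iteration order
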
Code Example #3
File: data.py  Project: NLP1502/NLP
class Data:
    def __init__(self):
        self.substring_names = ['word', 'pos', 'char', 'bpe', 'word-pos']
        self.substring_maxlen = 10

        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.norm_trans_emb = False
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('character')

        self.translation_alphabet = Alphabet('translation')
        self.translation_id_format = {}

        self.feature_names = []
        self.feature_alphabets = []
        self.feature_num = len(self.feature_alphabets)
        self.feat_config = None

        self.label_alphabet = Alphabet('label', True)
        self.tagScheme = "NoSeg"  ## BMES/BIO

        self.seg = True
        ###
        self.task_name = None

        ### I/O
        self.data_bin_dir = None
        self.train_dir = None
        self.dev_dir = None
        self.test_dir = None
        self.raw_dir = None
        self.middle_dir = None
        self.viterbi_inputs_model_name = None

        self.trans_dir = None

        self.decode_dir = None
        self.model_dir = None  ## model save  file
        self.load_model_dir = None  ## model load file

        self.word_emb_dir = None
        self.char_emb_dir = None
        self.trans_embed_dir = None
        self.typeinfo_dir = None

        self.feature_emb_dirs = []

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.pretrain_trans_embedding = None
        self.pretrain_feature_embeddings = []

        self.label_size = 0
        self.word_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0
        self.trans_alphabet_size = 0

        self.feature_alphabet_sizes = []
        self.feature_emb_dims = []
        self.norm_feature_embs = []
        self.word_emb_dim = 50
        self.char_emb_dim = 30
        self.trans_emb_dim = 100

        ###Classification
        ## Dataset Plus
        self.substring_dir = None
        self.bpe_emb_dir = None
        self.pos_emb_dir = None
        self.pretrain_bpe_embedding = None
        self.pretrain_pos_embedding = None
        self.bpe_emb_dim = 30
        self.pos_emb_dim = 30
        self.bpe_alphabet_size = 0
        self.pos_alphabet_size = 0
        self.norm_bpe_emb = False
        self.norm_pos_emb = False
        self.bpe_texts = []
        self.bpe_Ids = []
        self.pos_texts = []
        self.pos_Ids = []
        self.label_size = 0
        self.substring_train_texts = None
        self.substring_train_Ids = None
        self.substring_dev_texts = None
        self.substring_dev_Ids = None
        self.substring_test_texts = None
        self.substring_test_Ids = None
        self.substring_label_alphabet = Alphabet('substring_label', True)

        ###Networks
        self.word_feature_extractor = "LSTM"  # "LSTM"/"CNN"/"GRU"/
        self.use_char = True
        self.char_seq_feature = "CNN"  # "LSTM"/"CNN"/"GRU"/None
        self.use_trans = False
        self.use_crf = True
        self.nbest = None
        self.use_mapping = False
        self.mapping_func = None  # tanh or sigmoid

        # Training
        self.save_model = True
        self.state_training_name = 'default'
        self.average_batch_loss = False
        self.optimizer = "SGD"  # "SGD"/"Adam"
        self.status = "train"
        self.show_loss_per_batch = 100
        # Hyperparameters
        self.seed_num = None
        self.cnn_layer = 4
        self.iteration = 100
        self.batch_size = 10
        self.char_hidden_dim = 50
        self.trans_hidden_dim = 50
        self.hidden_dim = 200
        self.dropout = 0.5
        self.lstm_layer = 1
        self.bilstm = True

        self.gpu = False
        self.lr = 0.015
        self.lr_decay = 0.05
        self.clip = None
        self.momentum = 0
        self.l2 = 1e-8

        # circul
        self.circul_time = 4
        self.circul_deepth = 2
        self.circul_gather_output_mode = "concat"

        # decode prepare
        self.decode_prepare_mode = 'example'

    def init_substring_instance(self):
        len_names = len(self.substring_names)
        self.substring_train_texts = [[[]
                                       for _ in range(self.substring_maxlen)]
                                      for _ in range(len_names)]
        self.substring_train_Ids = [[[] for _ in range(self.substring_maxlen)]
                                    for _ in range(len_names)]
        self.substring_dev_texts = [[[] for _ in range(self.substring_maxlen)]
                                    for _ in range(len_names)]
        self.substring_dev_Ids = [[[] for _ in range(self.substring_maxlen)]
                                  for _ in range(len_names)]
        self.substring_test_texts = [[[] for _ in range(self.substring_maxlen)]
                                     for _ in range(len_names)]
        self.substring_test_Ids = [[[] for _ in range(self.substring_maxlen)]
                                   for _ in range(len_names)]

    def show_data_summary(self):
        print("++" * 50)
        print("DATA SUMMARY START:")
        print(" I/O:")
        print("     Tag          scheme: %s" % (self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s" % (self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s" % (self.number_normalized))
        print("     Word  alphabet size: %s" % (self.word_alphabet_size))
        print("     Char  alphabet size: %s" % (self.char_alphabet_size))
        print("     Label alphabet size: %s" % (self.label_alphabet_size))
        print("     Trans alphabet size: %s" % (self.trans_alphabet_size))
        print("     Word embedding  dir: %s" % (self.word_emb_dir))
        print("     Char embedding  dir: %s" % (self.char_emb_dir))
        print("     Tran embedding  dir: %s" % (self.trans_embed_dir))
        print("     Word embedding size: %s" % (self.word_emb_dim))
        print("     Char embedding size: %s" % (self.char_emb_dim))
        print("     Tran embedding size: %s" % (self.trans_emb_dim))
        print("     Norm   word     emb: %s" % (self.norm_word_emb))
        print("     Norm   char     emb: %s" % (self.norm_char_emb))
        print("     Norm   tran     emb: %s" % (self.norm_trans_emb))
        print("++" * 50)
        print("   task name: %s" % (self.task_name))
        print("++" * 50)
        print("   Data bin file directory: %s" % (self.data_bin_dir))
        print("     Train  file directory: %s" % (self.train_dir))
        print("     Dev    file directory: %s" % (self.dev_dir))
        print("     Test   file directory: %s" % (self.test_dir))
        print("     Raw    file directory: %s" % (self.raw_dir))
        print("     Middle file directory: %s" % (self.middle_dir))
        print(" viterbi inputs model name: %s" %
              (self.viterbi_inputs_model_name))
        if self.typeinfo_dir:
            print("     typeinfo    directory: %s" % (self.typeinfo_dir))
        print("     Model  file directory: %s" % (self.model_dir))
        print("     Loadmodel   directory: %s" % (self.load_model_dir))
        print("     Decode file directory: %s" % (self.decode_dir))
        print("     Train instance number: %s" % (len(self.train_texts)))
        print("     Dev   instance number: %s" % (len(self.dev_texts)))
        print("     Test  instance number: %s" % (len(self.test_texts)))
        print("     Raw   instance number: %s" % (len(self.raw_texts)))
        print("     FEATURE num: %s" % (self.feature_num))
        for idx in range(self.feature_num):
            print("         Fe: %s  alphabet  size: %s" %
                  (self.feature_alphabets[idx].name,
                   self.feature_alphabet_sizes[idx]))
            print(
                "         Fe: %s  embedding  dir: %s" %
                (self.feature_alphabets[idx].name, self.feature_emb_dirs[idx]))
            print(
                "         Fe: %s  embedding size: %s" %
                (self.feature_alphabets[idx].name, self.feature_emb_dims[idx]))
            print("         Fe: %s  norm       emb: %s" %
                  (self.feature_alphabets[idx].name,
                   self.norm_feature_embs[idx]))
        print(" " + "++" * 20)
        print(" Model Network:")
        print("     Model        use_crf: %s" % (self.use_crf))
        print("     Model word extractor: %s" % (self.word_feature_extractor))
        print("     Model       use_char: %s" % (self.use_char))
        if self.use_char:
            print("     Model char_seq_feature: %s" % (self.char_seq_feature))
            print("     Model char_hidden_dim: %s" % (self.char_hidden_dim))
        if self.use_trans:
            print("     Model trans_hidden_dim: %s" % (self.trans_hidden_dim))
        if self.use_mapping:
            print("     Model mapping function: %s" % (self.mapping_func))
        print(" " + "++" * 20)
        print(" Training:")
        print("     show_loss_per_batch: %s" % (self.show_loss_per_batch))
        print("     save_model: %s" % (self.save_model))
        print("     state_training_name: %s" % (self.state_training_name))
        print("     Optimizer: %s" % (self.optimizer))
        print("     Iteration: %s" % (self.iteration))
        print("     BatchSize: %s" % (self.batch_size))
        print("     Average  batch   loss: %s" % (self.average_batch_loss))

        print(" " + "++" * 20)
        print(" Hyperparameters:")

        print("     Hyper        seed_num: %s" % (self.seed_num))
        print("     Hyper              lr: %s" % (self.lr))
        print("     Hyper        lr_decay: %s" % (self.lr_decay))
        print("     Hyper            clip: %s" % (self.clip))
        print("     Hyper        momentum: %s" % (self.momentum))
        print("     Hyper              l2: %s" % (self.l2))
        print("     Hyper      hidden_dim: %s" % (self.hidden_dim))
        print("     Hyper         dropout: %s" % (self.dropout))
        print("     Hyper      lstm_layer: %s" % (self.lstm_layer))
        print("     Hyper          bilstm: %s" % (self.bilstm))
        print("     Hyper             GPU: %s" % (self.gpu))
        print("DATA SUMMARY END.")
        print("++" * 50)

        print("      substring dir : %s" % (self.substring_dir))
        print("    bpe_emb_dir dir : %s" % (self.bpe_emb_dir))
        print("    pos_emb_dir dir : %s" % (self.pos_emb_dir))
        print("++" * 50)

        print("      circul time   : %s" % (self.circul_time))
        print("      circul deepth : %s" % (self.circul_deepth))
        print(" gather output mode : %s" % (self.circul_gather_output_mode))
        print("++" * 50)

        print(" decode prepare mode : %s" % (self.decode_prepare_mode))
        print("++" * 50)

        sys.stdout.flush()

    def make_substring_label_alphabet(self):
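        ## strip the BIO/BMES prefix, keeping only the entity type ("B-PER" -> "PER")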
        for label in self.label_alphabet.instances:
            label = label.split('-')[-1]
            self.substring_label_alphabet.add(label)
        self.substring_label_alphabet.close()

    def initial_feature_alphabets(self):
        items = open(self.train_dir, 'r').readline().strip('\n').split()
        total_column = len(items)
        if total_column > 2:
            for idx in range(1, total_column - 1):
                feature_prefix = 'feature_' + str(idx)
                self.feature_alphabets.append(Alphabet(feature_prefix))
                self.feature_names.append(feature_prefix)
                print "Find feature: ", feature_prefix
        self.feature_num = len(self.feature_alphabets)
        self.pretrain_feature_embeddings = [None] * self.feature_num
        self.feature_emb_dims = [20] * self.feature_num
        self.feature_emb_dirs = [None] * self.feature_num
        self.norm_feature_embs = [False] * self.feature_num
        self.feature_alphabet_sizes = [0] * self.feature_num
        if self.feat_config:
            for idx in range(self.feature_num):
                self.feature_emb_dims[idx] = self.feat_config[
                    self.feature_names[idx]]['emb_size']
                self.feature_emb_dirs[idx] = self.feat_config[
                    self.feature_names[idx]]['emb_dir']
                self.norm_feature_embs[idx] = self.feat_config[
                    self.feature_names[idx]]['emb_norm']
        # exit(0)
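
    # Assumed input layout (inferred from the "split(']', 1)" parsing in
    # build_alphabet below): CoNLL-style columns where the first column is
    # the word, the last is the label, and each middle column becomes
    # feature_1 ... feature_N, e.g.
    #
    #   EU [POS]NNP [Cap]1 B-ORG
    #   rejects [POS]VBZ [Cap]0 O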

    def build_alphabet(self, input_file):
        in_lines = open(input_file, 'r').readlines()
        for line in in_lines:
            if len(line) > 2:
                pairs = line.strip().split()
                word = pairs[0].decode('windows-1252')
                # word = pairs[0].decode('utf-8')
                if self.number_normalized:
                    word = normalize_word(word)
                label = pairs[-1]
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)
                ## build feature alphabet
                for idx in range(self.feature_num):
                    feat_idx = pairs[idx + 1].split(']', 1)[-1]
                    self.feature_alphabets[idx].add(feat_idx)
                for char in word:
                    self.char_alphabet.add(char)
        self.word_alphabet_size = self.word_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        for idx in range(self.feature_num):
            self.feature_alphabet_sizes[idx] = self.feature_alphabets[
                idx].size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"

    def build_alphabet_substring(self, input_file_dir, substring_file_prefix):
        ## labels are not read here
        input_files = os.listdir(input_file_dir)
        print(input_files)
        for input_file in input_files:
            plus_feature = ''
            input_file_name = os.path.split(input_file)[1]
            if input_file_name.split('.')[0] != substring_file_prefix:
                continue
            if 'bpe' in input_file_name:
                plus_feature = 'bpe'
            elif 'word' in input_file_name:
                plus_feature = 'word'
            if plus_feature == '':
                continue
            in_lines = open(os.path.join(input_file_dir, input_file), 'r').readlines()
            for line in in_lines:
                if len(line.strip()) > 0:
                    pairs = line.strip().split('\t')
                    words = pairs[0].decode('windows-1252')
                    # word = pairs[0].decode('utf-8')
                    if self.number_normalized:
                        words = normalize_word(words)
                    labels = pairs[-1]
                    for word in words.split():
                        self.word_alphabet.add(word)
                        for char in word:
                            self.char_alphabet.add(char)
            self.word_alphabet_size = self.word_alphabet.size()
            self.char_alphabet_size = self.char_alphabet.size()

    def fix_alphabet(self):
        self.word_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close()
        self.translation_alphabet.close()
        for idx in range(self.feature_num):
            self.feature_alphabets[idx].close()

    def build_pretrain_emb(self):
        if self.word_emb_dir:
            print("Load pretrained word embedding, norm: %s, dir: %s" %
                  (self.norm_word_emb, self.word_emb_dir))
            self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
                self.word_emb_dir, self.word_alphabet, self.word_emb_dim,
                self.norm_word_emb)

            if self.typeinfo_dir:
                type_info_matrix = []
                with codecs.open(self.typeinfo_dir, 'r') as typeinfo_file:
                    type_info_lines = typeinfo_file.readlines()
                    for line in type_info_lines:
                        line = line.rstrip().split()
                        for i, _ in enumerate(line):
                            line[i] = float(line[i])
                        line = np.array(line)
                        type_info_matrix.append(line)

                print(
                    "Caculate type info distribution,and concate word and type......"
                )
                cos_res = []
                for i, word_embed in enumerate(self.pretrain_word_embedding):
                    word_type_info = []
                    if i == 0:
                        word_type_info = np.random.random(
                            size=len(type_info_matrix))
                        cos_res.append(word_type_info)
                    else:
                        for type_info in type_info_matrix:
                            cos_sim = 1 - spatial.distance.cosine(
                                word_embed, type_info)
                            word_type_info.append(cos_sim)
                        cos_res.append(word_type_info)
                cos_res = np.array(cos_res)
                cos_res = sigmoid(cos_res)
                self.pretrain_word_embedding = np.concatenate(
                    [self.pretrain_word_embedding, cos_res], axis=1)
                print "type info length:{}".format(len(type_info_matrix))
                self.word_emb_dim += len(type_info_matrix)
                print "new word dim is :{}".format(self.word_emb_dim)

        if self.char_emb_dir:
            print("Load pretrained char embedding, norm: %s, dir: %s" %
                  (self.norm_char_emb, self.char_emb_dir))
            self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(
                self.char_emb_dir, self.char_alphabet, self.char_emb_dim,
                self.norm_char_emb)
        if self.trans_embed_dir:
            print("Load pretrained trans embedding, norm: %s, dir: %s" %
                  (self.norm_trans_emb, self.trans_embed_dir))
            self.pretrain_trans_embedding, self.trans_emb_dim = build_chi_pretrain_embedding(
                self.trans_embed_dir, self.translation_alphabet,
                self.trans_emb_dim, self.norm_trans_emb)

        for idx in range(self.feature_num):
            if self.feature_emb_dirs[idx]:
                print(
                    "Load pretrained feature %s embedding:, norm: %s, dir: %s"
                    % (self.feature_name[idx], self.norm_feature_embs[idx],
                       self.feature_emb_dirs[idx]))
                self.pretrain_feature_embeddings[idx], self.feature_emb_dims[
                    idx] = build_pretrain_embedding(
                        self.feature_emb_dirs[idx],
                        self.feature_alphabets[idx],
                        self.feature_emb_dims[idx],
                        self.norm_feature_embs[idx])

    def generate_instance(self, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance(
                self.train_dir, self.word_alphabet, self.char_alphabet,
                self.feature_alphabets, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH,
                self.translation_id_format)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance(
                self.dev_dir, self.word_alphabet, self.char_alphabet,
                self.feature_alphabets, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH,
                self.translation_id_format)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance(
                self.test_dir, self.word_alphabet, self.char_alphabet,
                self.feature_alphabets, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH,
                self.translation_id_format)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance(
                self.raw_dir, self.word_alphabet, self.char_alphabet,
                self.feature_alphabets, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH,
                self.translation_id_format)
        else:
            print(
                "Error: you can only generate train/dev/test instance! Illegal input:%s"
                % (name))

    def generate_instance_substring(self, substring_file_prefix):
        self.init_substring_instance()
        self.make_substring_label_alphabet()
        input_files = os.listdir(self.substring_dir)
        print(input_files)
        for input_file in input_files:
            input_file_name = os.path.split(input_file)[1]
            input_file_dir = os.path.join(self.substring_dir, input_file_name)
            input_file_name_split = input_file_name.split('.')
            if input_file_name_split[0] != substring_file_prefix:
                continue
            print('dealing with %s' % (input_file_name))
            name = input_file_name_split[1]
            feature_name = input_file_name_split[2]
            f_l = int(input_file_name_split[-1][3:])  #feature_len

            if feature_name == 'word':
                alphabet = self.word_alphabet
            elif feature_name == 'char':
                alphabet = self.char_alphabet
            elif feature_name == 'pos':
                alphabet = self.feature_alphabets[0]
            elif feature_name == 'bpe':
                alphabet = self.feature_alphabets[1]

            s_f_id = self.substring_names.index(
                feature_name)  #substring_feature_id
            if name == "train":
                self.substring_train_texts[s_f_id][f_l], self.substring_train_Ids[s_f_id][f_l]\
                    = read_instance_substring(input_file_dir, alphabet, self.substring_label_alphabet, self.number_normalized)
            elif name == "testa":
                self.substring_dev_texts[s_f_id][f_l], self.substring_dev_Ids[s_f_id][f_l] \
                    = read_instance_substring(input_file_dir, alphabet, self.substring_label_alphabet, self.number_normalized)
            elif name == "testb":
                self.substring_test_texts[s_f_id][f_l], self.substring_test_Ids[s_f_id][f_l] \
                    = read_instance_substring(input_file_dir, alphabet, self.substring_label_alphabet, self.number_normalized)
            else:
                print(
                    "Error: you can only generate train/testa/testb instance! Illegal input:%s"
                    % (name))

    def write_decoded_results(self, predict_results, name):
        fout = open(self.decode_dir, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be within train/dev/test/raw !"
            )
        assert (sent_num == len(content_list))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                ## content_list[idx] is a list with [word, char, label]
                fout.write(content_list[idx][0][idy].encode('utf-8') + " " +
                           predict_results[idx][idy] + '\n')
            fout.write('\n')
        fout.close()
        print("Predict %s result has been written into file. %s" %
              (name, self.decode_dir))

    def load(self, data_file):
        f = open(data_file, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def save(self, save_file):
        f = open(save_file, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()
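
    # Minimal usage sketch (the file name is illustrative): save/load
    # round-trip the full __dict__ through pickle, so a Data object built
    # for training can be restored unchanged for decoding:
    #
    #   data.save("ner.dset")
    #   fresh = Data()
    #   fresh.load("ner.dset")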

    def write_nbest_decoded_results(self, predict_results, pred_scores, name):
        ## predict_results : [whole_sent_num, nbest, each_sent_length]
        ## pred_scores: [whole_sent_num, nbest]
        fout = open(self.decode_dir, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be within train/dev/test/raw !"
            )
        assert (sent_num == len(content_list))
        assert (sent_num == len(pred_scores))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx][0])
            nbest = len(predict_results[idx])
            score_string = "# "
            for idz in range(nbest):
                score_string += format(pred_scores[idx][idz], '.4f') + " "
            fout.write(score_string.strip() + "\n")

            for idy in range(sent_length):
                label_string = content_list[idx][0][idy].encode('utf-8') + " "
                for idz in range(nbest):
                    label_string += predict_results[idx][idz][idy] + " "
                label_string = label_string.strip() + "\n"
                fout.write(label_string)
            fout.write('\n')
        fout.close()
        print("Predict %s %s-best result has been written into file. %s" %
              (name, nbest, self.decode_dir))

    def read_config(self, config_file):
        config = config_file_to_dict(config_file)
        ## task:
        the_item = 'task_name'
        if the_item in config:
            self.task_name = config[the_item]

        ## read data:
        the_item = 'data_bin_dir'
        if the_item in config:
            self.data_bin_dir = config[the_item]
        the_item = 'train_dir'
        if the_item in config:
            self.train_dir = config[the_item]
        the_item = 'dev_dir'
        if the_item in config:
            self.dev_dir = config[the_item]
        the_item = 'test_dir'
        if the_item in config:
            self.test_dir = config[the_item]
        the_item = 'trans_dir'
        if the_item in config:
            self.trans_dir = config[the_item]
        the_item = 'middle_dir'
        if the_item in config:
            self.middle_dir = config[the_item]
        the_item = 'viterbi_inputs_model_name'
        if the_item in config:
            self.viterbi_inputs_model_name = config[the_item]

        the_item = 'substring_dir'
        if the_item in config:
            self.substring_dir = config[the_item]
        the_item = 'bpe_emb_dir'
        if the_item in config:
            self.bpe_emb_dir = config[the_item]
        the_item = 'pos_emb_dir'
        if the_item in config:
            self.pos_emb_dir = config[the_item]

        the_item = 'raw_dir'
        if the_item in config:
            self.raw_dir = config[the_item]
        the_item = 'decode_dir'
        if the_item in config:
            self.decode_dir = config[the_item]
        the_item = 'model_dir'
        if the_item in config:
            self.model_dir = config[the_item]
        the_item = 'load_model_dir'
        if the_item in config:
            self.load_model_dir = config[the_item]

        the_item = 'word_emb_dir'
        if the_item in config:
            self.word_emb_dir = config[the_item]
        the_item = 'char_emb_dir'
        if the_item in config:
            self.char_emb_dir = config[the_item]
        the_item = 'trans_embed_dir'
        if the_item in config:
            self.trans_embed_dir = config[the_item]
        the_item = 'typeinfo_dir'
        if the_item in config:
            self.typeinfo_dir = config[the_item]

        the_item = 'MAX_SENTENCE_LENGTH'
        if the_item in config:
            self.MAX_SENTENCE_LENGTH = int(config[the_item])
        the_item = 'MAX_WORD_LENGTH'
        if the_item in config:
            self.MAX_WORD_LENGTH = int(config[the_item])

        the_item = 'norm_word_emb'
        if the_item in config:
            self.norm_word_emb = str2bool(config[the_item])
        the_item = 'norm_char_emb'
        if the_item in config:
            self.norm_char_emb = str2bool(config[the_item])
        the_item = 'number_normalized'
        if the_item in config:
            self.number_normalized = str2bool(config[the_item])

        the_item = 'seg'
        if the_item in config:
            self.seg = str2bool(config[the_item])
        the_item = 'word_emb_dim'
        if the_item in config:
            self.word_emb_dim = int(config[the_item])
        the_item = 'char_emb_dim'
        if the_item in config:
            self.char_emb_dim = int(config[the_item])
        the_item = 'trans_emb_dim'
        if the_item in config:
            self.trans_emb_dim = int(config[the_item])

        ## read network:
        the_item = 'use_crf'
        if the_item in config:
            self.use_crf = str2bool(config[the_item])
        the_item = 'use_char'
        if the_item in config:
            self.use_char = str2bool(config[the_item])
        the_item = 'use_trans'
        if the_item in config:
            self.use_trans = str2bool(config[the_item])
        the_item = 'use_mapping'
        if the_item in config:
            self.use_mapping = str2bool(config[the_item])
        the_item = 'mapping_func'
        if the_item in config:
            self.mapping_func = config[the_item]
        the_item = 'word_seq_feature'
        if the_item in config:
            self.word_feature_extractor = config[the_item]
        the_item = 'char_seq_feature'
        if the_item in config:
            self.char_seq_feature = config[the_item]
        the_item = 'nbest'
        if the_item in config:
            self.nbest = int(config[the_item])

        the_item = 'feature'
        if the_item in config:
            self.feat_config = config[the_item]  ## feat_config is a dict

        ## read training setting:
        the_item = 'save_model'
        if the_item in config:
            self.save_model = str2bool(config[the_item])
        the_item = 'state_training_name'
        if the_item in config:
            self.state_training_name = config[the_item]
        the_item = 'optimizer'
        if the_item in config:
            self.optimizer = config[the_item]
        the_item = 'ave_batch_loss'
        if the_item in config:
            self.average_batch_loss = str2bool(config[the_item])
        the_item = 'status'
        if the_item in config:
            self.status = config[the_item]
        the_item = 'show_loss_per_batch'
        if the_item in config:
            self.show_loss_per_batch = int(config[the_item])

        ## read Hyperparameters:
        the_item = 'seed_num'
        if the_item in config:
            if config[the_item] != 'None':
                self.seed_num = int(config[the_item])
        the_item = 'cnn_layer'
        if the_item in config:
            self.cnn_layer = int(config[the_item])
        the_item = 'iteration'
        if the_item in config:
            self.iteration = int(config[the_item])
        the_item = 'batch_size'
        if the_item in config:
            self.batch_size = int(config[the_item])

        the_item = 'char_hidden_dim'
        if the_item in config:
            self.char_hidden_dim = int(config[the_item])

        the_item = 'trans_hidden_dim'
        if the_item in config:
            self.trans_hidden_dim = int(config[the_item])

        the_item = 'hidden_dim'
        if the_item in config:
            self.hidden_dim = int(config[the_item])
        the_item = 'dropout'
        if the_item in config:
            self.dropout = float(config[the_item])
        the_item = 'lstm_layer'
        if the_item in config:
            self.lstm_layer = int(config[the_item])
        the_item = 'bilstm'
        if the_item in config:
            self.bilstm = str2bool(config[the_item])

        the_item = 'gpu'
        if the_item in config:
            self.gpu = str2bool(config[the_item])
        the_item = 'learning_rate'
        if the_item in config:
            self.lr = float(config[the_item])
        the_item = 'lr_decay'
        if the_item in config:
            self.lr_decay = float(config[the_item])
        the_item = 'clip'
        if the_item in config:
            if config[the_item] == 'None':
                self.clip = None
            else:
                self.clip = float(config[the_item])
        the_item = 'momentum'
        if the_item in config:
            self.momentum = float(config[the_item])
        the_item = 'l2'
        if the_item in config:
            self.l2 = float(config[the_item])

        ###base2
        the_item = 'feature_name'
        if the_item in config:
            self.feature_name = config[the_item]
        the_item = 'feature_length'
        if the_item in config:
            self.feature_length = int(config[the_item])
        the_item = 'class_num'
        if the_item in config:
            self.class_num = int(config[the_item])
        the_item = 'feature_ans'
        if the_item in config:
            self.feature_ans = config[the_item]

        ###circul
        the_item = 'circul_time'
        if the_item in config:
            self.circul_time = config[the_item]
        the_item = 'circul_deepth'
        if the_item in config:
            self.circul_deepth = config[the_item]
        the_item = 'circul_gather_output_mode'
        if the_item in config:
            self.circul_gather_output_mode = config[the_item]

        ###decode_prepare
        the_item = 'decode_prepare_mode'
        if the_item in config:
            self.decode_prepare_mode = config[the_item]
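
    # A sketch of the config format consumed above. The exact behaviour of
    # config_file_to_dict is not shown here; it is assumed to yield one
    # key/value pair per "key=value" line (NCRF++-style), e.g.
    #
    #   train_dir=data/train.bmes
    #   dev_dir=data/dev.bmes
    #   word_emb_dir=data/glove.6B.50d.txt
    #   optimizer=SGD
    #   learning_rate=0.015
    #   use_crf=True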

    def read_arg(self, args):
        if args.task_name != None: self.task_name = args.task_name

        if args.data_bin_dir != None: self.data_bin_dir = args.data_bin_dir
        if args.train_dir != None: self.train_dir = args.train_dir
        if args.dev_dir != None: self.dev_dir = args.dev_dir
        if args.test_dir != None: self.test_dir = args.test_dir
        if args.trans_dir != None: self.trans_dir = args.trans_dir
        if args.word_emb_dir != None: self.word_emb_dir = args.word_emb_dir
        if args.trans_embed_dir != None:
            self.trans_embed_dir = args.trans_embed_dir
        if args.middle_dir != None: self.middle_dir = args.middle_dir
        if args.viterbi_inputs_model_name != None:
            self.viterbi_inputs_model_name = args.viterbi_inputs_model_name

        if args.substring_dir != None: self.substring_dir = args.substring_dir
        if args.bpe_emb_dir != None: self.bpe_emb_dir = args.bpe_emb_dir
        if args.pos_emb_dir != None: self.pos_emb_dir = args.pos_emb_dir

        if args.model_dir != None: self.model_dir = args.model_dir
        if args.norm_word_emb != None: self.norm_word_emb = args.norm_word_emb
        if args.norm_char_emb != None: self.norm_char_emb = args.norm_char_emb
        if args.word_emb_dim != None: self.word_emb_dim = args.word_emb_dim
        if args.char_emb_dim != None: self.char_emb_dim = args.char_emb_dim
        if args.trans_emb_dim != None: self.trans_emb_dim = args.trans_emb_dim

        if args.number_normalized != None:
            self.number_normalized = args.number_normalized
        if args.seg != None: self.seg = args.seg

        if args.use_crf != None: self.use_crf = args.use_crf
        if args.use_char != None: self.use_char = args.use_char
        if args.use_trans != None: self.use_trans = args.use_trans

        if args.word_seq_feature != None:
            self.word_feature_extractor = args.word_seq_feature
        if args.char_seq_feature != None:
            self.char_seq_feature = args.char_seq_feature

        if args.nbest != None: self.nbest = args.nbest

        if args.status != None: self.status = args.status
        if args.state_training_name != None:
            self.state_training_name = args.state_training_name
        if args.save_model != None: self.save_model = args.save_model
        if args.optimizer != None: self.optimizer = args.optimizer
        if args.iteration != None: self.iteration = args.iteration
        if args.batch_size != None: self.batch_size = args.batch_size
        if args.ave_batch_loss != None:
            self.average_batch_loss = args.ave_batch_loss
        if args.show_loss_per_batch != None:
            self.show_loss_per_batch = args.show_loss_per_batch

        if args.seed_num != None: self.seed_num = args.seed_num
        if args.cnn_layer != None: self.cnn_layer = args.cnn_layer
        if args.char_hidden_dim != None:
            self.char_hidden_dim = args.char_hidden_dim
        if args.trans_hidden_dim != None:
            self.trans_hidden_dim = args.trans_hidden_dim
        if args.hidden_dim != None: self.hidden_dim = args.hidden_dim
        if args.dropout != None: self.dropout = args.dropout
        if args.lstm_layer != None: self.lstm_layer = args.lstm_layer
        if args.bilstm != None: self.bilstm = args.bilstm
        if args.learning_rate != None: self.lr = args.learning_rate
        if args.lr_decay != None: self.lr_decay = args.lr_decay
        if args.momentum != None: self.momentum = args.momentum
        if args.l2 != None: self.l2 = args.l2
        if args.gpu != None: self.gpu = args.gpu
        if args.clip != None: self.clip = args.clip

        ###base2
        if args.feature_name != None: self.feature_name = args.feature_name
        if args.feature_length != None:
            self.feature_length = args.feature_length
        if args.class_num != None: self.class_num = args.class_num
        if args.feature_ans != None:
            self.feature_ans = args.feature_ans

        ###circul
        if args.circul_time != None: self.circul_time = args.circul_time
        if args.circul_deepth != None: self.circul_deepth = args.circul_deepth
        if args.circul_gather_output_mode != None:
            self.circul_gather_output_mode = args.circul_gather_output_mode

        ###decode_prepare
        if args.decode_prepare_mode != None:
            self.decode_prepare_mode = args.decode_prepare_mode

    def build_translation_alphabet(self, trans_path):
        print("Creating translation alphabet......")
        with codecs.open(trans_path, 'r', "utf-8") as f:
            lines = f.readlines()
            for line in lines:
                if len(line.strip().split(":")) == 2:
                    temp = line.strip().split(":", 1)
                    words = temp[1].split()
                    for word in words:
                        self.translation_alphabet.add(word.strip())
        self.trans_alphabet_size = self.translation_alphabet.size()
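
    # Assumed translation-file layout, inferred from the split(":", 1)
    # parsing above and in build_translation_dict below: one entry per line,
    # "source_word:translation_1 translation_2 ...".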

    def build_translation_dict(self, trans_path):
        print("Creating Id to Id translation dictionary......")
        translation_id_format_temp = {}
        with codecs.open(trans_path, 'r', "utf-8") as f:
            lines = f.readlines()
            for line in lines:
                ids = []
                if len(line.strip().split(":", 1)) == 2:
                    temp = line.strip().split(":", 1)
                    word_id = self.word_alphabet.get_index(temp[0].strip())
                    translations = temp[1].split()
                    for translation in translations:
                        ids.append(
                            self.translation_alphabet.get_index(
                                translation.strip()))
                    if ids == []:
                        ids = [0]
                    translation_id_format_temp[word_id] = ids

        for word in self.word_alphabet.instances:
            if self.word_alphabet.get_index(
                    word) in translation_id_format_temp.keys():
                self.translation_id_format[self.word_alphabet.get_index(
                    word)] = translation_id_format_temp[
                        self.word_alphabet.get_index(word)]
            else:
                self.translation_id_format[self.word_alphabet.get_index(
                    word)] = [0]
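
The methods above form a complete preprocessing pipeline. A minimal driver
sketch follows; the config path, embedding files, and output name are
illustrative assumptions, not part of the original listing:

# Hypothetical end-to-end driver for the Data class above.
data = Data()
data.read_config("train.config")          # illustrative config path
data.initial_feature_alphabets()
for f in (data.train_dir, data.dev_dir, data.test_dir):
    data.build_alphabet(f)
data.build_translation_alphabet(data.trans_dir)
data.build_translation_dict(data.trans_dir)
data.build_pretrain_emb()
for name in ("train", "dev", "test"):
    data.generate_instance(name)          # fixes the alphabets internally
data.show_data_summary()
data.save("ner.dset")                     # illustrative output path
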
Code Example #4
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('character')
        # self.word_alphabet.add(START)
        # self.word_alphabet.add(UNKNOWN)
        # self.char_alphabet.add(START)
        # self.char_alphabet.add(UNKNOWN)
        # self.char_alphabet.add(PADDING)
        self.label_alphabet = Alphabet('label', True)
        self.tagScheme = "NoSeg"
        self.char_features = "LSTM"  ## "LSTM"/"CNN"

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []

        self.word_emb_dim = 50
        self.char_emb_dim = 30
        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.label_size = 0
        self.word_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0
        ### hyperparameters
        self.HP_iteration = 100
        self.HP_batch_size = 10
        self.HP_average_batch_loss = False
        self.HP_char_hidden_dim = 50
        self.HP_hidden_dim = 50
        self.HP_dropout = 0.5
        self.HP_lstm_layer = 1
        self.HP_bilstm = True
        self.HP_use_char = False
        self.HP_gpu = False
        self.HP_lr = 0.015
        self.HP_lr_decay = 0.05
        self.HP_clip = None
        self.HP_momentum = 0

    def show_data_summary(self):
        print("DATA SUMMARY START:")
        print("     Tag          scheme: %s" % (self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s" % (self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s" % (self.number_normalized))
        print("     Word  alphabet size: %s" % (self.word_alphabet_size))
        print("     Char  alphabet size: %s" % (self.char_alphabet_size))
        print("     Label alphabet size: %s" % (self.label_alphabet_size))
        print("     Word embedding size: %s" % (self.word_emb_dim))
        print("     Char embedding size: %s" % (self.char_emb_dim))
        print("     Norm   word     emb: %s" % (self.norm_word_emb))
        print("     Norm   char     emb: %s" % (self.norm_char_emb))
        print("     Train instance number: %s" % (len(self.train_texts)))
        print("     Dev   instance number: %s" % (len(self.dev_texts)))
        print("     Test  instance number: %s" % (len(self.test_texts)))
        print("     Raw   instance number: %s" % (len(self.raw_texts)))
        print("     Hyper       iteration: %s" % (self.HP_iteration))
        print("     Hyper      batch size: %s" % (self.HP_batch_size))
        print("     Hyper   average batch: %s" % (self.HP_average_batch_loss))
        print("     Hyper              lr: %s" % (self.HP_lr))
        print("     Hyper        lr_decay: %s" % (self.HP_lr_decay))
        print("     Hyper         HP_clip: %s" % (self.HP_clip))
        print("     Hyper        momentum: %s" % (self.HP_momentum))
        print("     Hyper      hidden_dim: %s" % (self.HP_hidden_dim))
        print("     Hyper         dropout: %s" % (self.HP_dropout))
        print("     Hyper      lstm_layer: %s" % (self.HP_lstm_layer))
        print("     Hyper          bilstm: %s" % (self.HP_bilstm))
        print("     Hyper             GPU: %s" % (self.HP_gpu))
        print("     Hyper        use_char: %s" % (self.HP_use_char))
        if self.HP_use_char:
            print("             Char_features: %s" % (self.char_features))

        print("DATA SUMMARY END.")
        sys.stdout.flush()

    def refresh_label_alphabet(self, input_file):
        old_size = self.label_alphabet_size
        self.label_alphabet.clear(True)
        in_lines = open(input_file, 'r').readlines()
        for line in in_lines:
            if len(line) > 2:
                pairs = line.strip().split()
                label = pairs[-1]
                self.label_alphabet.add(label)
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"
        self.fix_alphabet()
        print("Refresh label alphabet finished: old:%s -> new:%s" %
              (old_size, self.label_alphabet_size))

    def extend_word_char_alphabet(self, input_file_list):
        old_word_size = self.word_alphabet_size
        old_char_size = self.char_alphabet_size
        for input_file in input_file_list:
            in_lines = open(input_file, 'r').readlines()
            for line in in_lines:
                if len(line) > 2:
                    pairs = line.strip().split()
                    word = pairs[0]
                    if self.number_normalized:
                        word = normalize_word(word)
                    self.word_alphabet.add(word)
                    for char in word:
                        self.char_alphabet.add(char)
        self.word_alphabet_size = self.word_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        print("Extend word/char alphabet finished!")
        print("     old word:%s -> new word:%s" %
              (old_word_size, self.word_alphabet_size))
        print("     old char:%s -> new char:%s" %
              (old_char_size, self.char_alphabet_size))
        for input_file in input_file_list:
            print("     from file:%s" % (input_file))

    def build_alphabet(self, input_file):
        in_lines_string = open(input_file + ".string.txt", 'r').readlines()
        in_lines_label = open(input_file + ".label.txt", 'r').readlines()
        for line_string, line_label in zip(in_lines_string, in_lines_label):
            print(line_label)
            print(line_string)
            line_label = line_label[:-1].split(',')
            line_string = line_string[:-1]
            assert len(line_label) == len(line_string)
            for i in range(len(line_label)):
                self.label_alphabet.add(line_label[i])
                self.word_alphabet.add(line_string[i])
        self.char_alphabet.add("*")
        self.word_alphabet_size = self.word_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"

    def fix_alphabet(self):
        self.word_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close()

    def build_word_pretrain_emb(self, emb_path):
        self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
            emb_path, self.word_alphabet, self.word_emb_dim,
            self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path):
        self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(
            emb_path, self.char_alphabet, self.char_emb_dim,
            self.norm_char_emb)

    def generate_instance(self, input_file, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance(
                input_file, self.word_alphabet, self.char_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance(
                input_file, self.word_alphabet, self.char_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance(
                input_file, self.word_alphabet, self.char_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance(
                input_file, self.word_alphabet, self.char_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_SENTENCE_LENGTH)
        else:
            print(
                "Error: you can only generate train/dev/test instance! Illegal input:%s"
                % (name))

    def write_decoded_results(self, output_file, predict_results, name):
        fout = open(output_file, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be within train/dev/test/raw !"
            )
        assert (sent_num == len(content_list))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                ## content_list[idx] is a list with [word, char, label]
                fout.write(content_list[idx][0][idy].encode('utf-8') + " " +
                           predict_results[idx][idy] + '\n')
            fout.write('\n')
        fout.close()
        print("Predict %s result has been written into file. %s" %
              (name, output_file))
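
This leaner variant takes explicit file paths instead of a config file. A
usage sketch under that assumption (all names are illustrative; per
build_alphabet above, the prefix "train" is expanded to "train.string.txt"
and "train.label.txt"):

# Hypothetical driver for this variant.
data = Data()
data.build_alphabet("train")
data.fix_alphabet()
data.build_word_pretrain_emb("emb.txt")   # optional pretrained embeddings
data.generate_instance("train", "train")
data.show_data_summary()
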
Code Example #5
File: data.py  Project: TAM-Lab/TAMRepository
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 230
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = False
        self.norm_word_emb = True
        self.norm_biword_emb = True
        self.norm_gaz_emb = False
        self.word_alphabet = Alphabet('word')
        self.biword_alphabet = Alphabet('biword')
        self.char_alphabet = Alphabet('character')
        # self.word_alphabet.add(START)
        # self.word_alphabet.add(UNKNOWN)
        # self.char_alphabet.add(START)
        # self.char_alphabet.add(UNKNOWN)
        # self.char_alphabet.add(PADDING)
        self.label_alphabet = Alphabet('label', True)
        self.gaz_lower = False
        self.gaz = Gazetteer(self.gaz_lower)
        self.gaz_alphabet = Alphabet('gaz')
        self.HP_fix_gaz_emb = False
        self.HP_use_gaz = True

        self.tagScheme = "BMES"
        self.char_features = "LSTM"

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []
        self.use_bigram = False
        self.word_emb_dim = 50
        self.biword_emb_dim = 50
        self.char_emb_dim = 50
        self.gaz_emb_dim = 50
        self.gaz_dropout = 0.5
        self.pretrain_word_embedding = None
        self.pretrain_biword_embedding = None
        self.pretrain_gaz_embedding = None
        self.label_size = 0
        self.word_alphabet_size = 0
        self.biword_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0
        # hyperparameters
        self.HP_iteration = 100
        self.HP_batch_size = 1
        self.HP_char_hidden_dim = 50
        self.HP_hidden_dim = 200
        self.HP_dropout = 0.5
        self.HP_lstm_layer = 1
        self.HP_bilstm = True
        self.HP_use_char = True
        self.HP_gpu = False
        self.HP_lr = 0.015
        self.HP_lr_decay = 0.05
        self.HP_clip = 5.0
        self.HP_momentum = 0

    def show_data_summary(self):
        print("DATA SUMMARY START:")
        print("     Tag          scheme: %s" % (self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s" % (self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s" % (self.number_normalized))
        print("     Use          bigram: %s" % (self.use_bigram))
        print("     Word  alphabet size: %s" % (self.word_alphabet_size))
        print("     Biword alphabet size: %s" % (self.biword_alphabet_size))
        print("     Char  alphabet size: %s" % (self.char_alphabet_size))
        print("     Gaz   alphabet size: %s" % (self.gaz_alphabet.size()))
        print("     Label alphabet size: %s" % (self.label_alphabet_size))
        print("     Word embedding size: %s" % (self.word_emb_dim))
        print("     Biword embedding size: %s" % (self.biword_emb_dim))
        print("     Char embedding size: %s" % (self.char_emb_dim))
        print("     Gaz embedding size: %s" % (self.gaz_emb_dim))
        print("     Norm     word   emb: %s" % (self.norm_word_emb))
        print("     Norm     biword emb: %s" % (self.norm_biword_emb))
        print("     Norm     gaz    emb: %s" % (self.norm_gaz_emb))
        print("     Norm   gaz  dropout: %s" % (self.gaz_dropout))
        print("     Train instance number: %s" % (len(self.train_texts)))
        print("     Dev   instance number: %s" % (len(self.dev_texts)))
        print("     Test  instance number: %s" % (len(self.test_texts)))
        print("     Raw   instance number: %s" % (len(self.raw_texts)))
        print("     Hyperpara  iteration: %s" % (self.HP_iteration))
        print("     Hyperpara  batch size: %s" % (self.HP_batch_size))
        print("     Hyperpara          lr: %s" % (self.HP_lr))
        print("     Hyperpara    lr_decay: %s" % (self.HP_lr_decay))
        print("     Hyperpara     HP_clip: %s" % (self.HP_clip))
        print("     Hyperpara    momentum: %s" % (self.HP_momentum))
        print("     Hyperpara  hidden_dim: %s" % (self.HP_hidden_dim))
        print("     Hyperpara     dropout: %s" % (self.HP_dropout))
        print("     Hyperpara  lstm_layer: %s" % (self.HP_lstm_layer))
        print("     Hyperpara      bilstm: %s" % (self.HP_bilstm))
        print("     Hyperpara         GPU: %s" % (self.HP_gpu))
        print("     Hyperpara     use_gaz: %s" % (self.HP_use_gaz))
        print("     Hyperpara fix gaz emb: %s" % (self.HP_fix_gaz_emb))
        print("     Hyperpara    use_char: %s" % (self.HP_use_char))
        if self.HP_use_char:
            print("             Char_features: %s" % (self.char_features))
        print("DATA SUMMARY END.")
        sys.stdout.flush()

    def refresh_label_alphabet(self, input_file):
        old_size = self.label_alphabet_size
        self.label_alphabet.clear(True)
        in_lines = open(input_file, 'r').readlines()
        for line in in_lines:
            if len(line) > 2:
                pairs = line.strip().split()
                label = pairs[-1]
                self.label_alphabet.add(label)
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"
        self.fix_alphabet()
        print("Refresh label alphabet finished: old:%s -> new:%s" %
              (old_size, self.label_alphabet_size))

    def build_alphabet(self, input_file):
        in_lines = open(input_file, 'r').readlines()
        for idx in range(len(in_lines)):
            line = in_lines[idx]
            if len(line) > 2:
                pairs = line.strip().split()
                word = pairs[0].decode('utf-8')
                if self.number_normalized:
                    word = normalize_word(word)
                #  get the label
                label = pairs[-1]
                # add in order of appearance
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)
                if idx < len(in_lines) - 1 and len(in_lines[idx + 1]) > 2:
                    biword = word + in_lines[
                        idx + 1].strip().split()[0].decode('utf-8')
                else:
                    biword = word + NULLKEY
                self.biword_alphabet.add(biword)
                for char in word:
                    self.char_alphabet.add(char)
        self.word_alphabet_size = self.word_alphabet.size()
        self.biword_alphabet_size = self.biword_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        # determine which scheme the labels use: BIO, BMES, or BIOES
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                # if S- tags exist, the scheme is BMES (or BIOES)
                self.tagScheme = "BMES"
            else:
                # otherwise BIO
                self.tagScheme = "BIO"

    def build_gaz_file(self, gaz_file):
        # build the gaz trie; read the gaz (lexicon) file at initialization
        if gaz_file:
            fins = open(gaz_file, 'r').readlines()
            for fin in fins:
                fin = fin.strip().split()[0].decode('utf-8')
                if fin:
                    self.gaz.insert(fin, "one_source")
            print("Load gaz file: %s, total size: %s" % (gaz_file, self.gaz.size()))
        else:
            print("Gaz file is None, load nothing")

    def build_gaz_alphabet(self, input_file):
        in_lines = open(input_file, 'r').readlines()
        word_list = []
        for line in in_lines:
            if len(line) > 3:
                word = line.split()[0].decode('utf-8')
                if self.number_normalized:
                    word = normalize_word(word)
                word_list.append(word)
            else:
                w_length = len(word_list)
                for idx in range(w_length):
                    matched_entity = self.gaz.enumerateMatchList(
                        word_list[idx:])
                    for entity in matched_entity:
                        # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity)
                        self.gaz_alphabet.add(entity)
                word_list = []
        print("gaz alphabet size: %s" % self.gaz_alphabet.size())

    def fix_alphabet(self):
        self.word_alphabet.close()
        self.biword_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close()
        self.gaz_alphabet.close()

    def build_word_pretrain_emb(self, emb_path):
        print("build word pretrain emb...")
        self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
            emb_path, self.word_alphabet, self.word_emb_dim,
            self.norm_word_emb)

    def build_radical_pretrain_emb(self, emb_path):
        print("build radical pretrain emb...")
        self.pretrain_word_embedding, self.word_emb_dim = build_radical_pretrain_embedding(
            emb_path, self.word_alphabet, self.word_emb_dim,
            self.norm_word_emb)

    def build_biword_pretrain_emb(self, emb_path):
        print("build biword pretrain emb...")
        self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(
            emb_path, self.biword_alphabet, self.biword_emb_dim,
            self.norm_biword_emb)

    def build_gaz_pretrain_emb(self, emb_path):
        print("build gaz pretrain emb...")
        self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(
            emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb)

    def generate_instance(self, input_file, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_seg_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_seg_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_seg_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_seg_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        else:
            print(
                "Error: you can only generate train/dev/test instance! Illegal input:%s"
                % (name))

    def generate_instance_with_gaz(self, input_file, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance_with_gaz(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance_with_gaz(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance_with_gaz(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance_with_gaz(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        else:
            print(
                "Error: you can only generate train/dev/test instance! Illegal input:%s"
                % (name))

    def write_decoded_results(self, output_file, predict_results, name):
        fout = open(output_file, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be within train/dev/test/raw !"
            )
        assert (sent_num == len(content_list))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                # content_list[idx] is a list with [word, char, label]
                fout.write(content_list[idx][0][idy].encode('utf-8') + " " +
                           predict_results[idx][idy] + '\n')

            fout.write('\n')
        fout.close()
        print("Predict %s result has been written into file. %s" %
              (name, output_file))
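
The gazetteer-aware variant needs two extra preparation steps before
instances can be generated: loading the lexicon into the Gazetteer trie and
building the gaz alphabet from entities matched in the data. A driver sketch
with illustrative paths:

# Hypothetical driver for the gazetteer variant.
data = Data()
data.build_alphabet("train.char.bmes")
data.build_gaz_file("lexicon.vec")            # one lexicon word per line
data.build_gaz_alphabet("train.char.bmes")    # add matched entities
data.fix_alphabet()
data.build_word_pretrain_emb("char_emb.vec")
data.build_gaz_pretrain_emb("lexicon.vec")
data.generate_instance_with_gaz("train.char.bmes", "train")
data.show_data_summary()
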
Code Example #6
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = True
        # self.punctuation_filter = True
        self.norm_word_emb = True
        self.norm_biword_emb = True
        self.norm_gaz_emb = False
        self.word_alphabet = Alphabet('word')
        self.biword_alphabet = Alphabet('biword')
        self.char_alphabet = Alphabet('character')
        # self.word_alphabet.add(START)
        # self.word_alphabet.add(UNKNOWN)
        # self.char_alphabet.add(START)
        # self.char_alphabet.add(UNKNOWN)
        # self.char_alphabet.add(PADDING)
        self.label_alphabet = Alphabet('label', True)
        self.gaz_lower = False
        self.gaz = Gazetteer(self.gaz_lower)
        self.gaz_alphabet = Alphabet('gaz')
        self.HP_fix_gaz_emb = False
        self.HP_use_gaz = True

        self.tagScheme = "NoSeg"
        self.char_features = "LSTM"

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []
        self.use_bigram = True
        self.word_emb_dim = 50
        self.biword_emb_dim = 50
        self.char_emb_dim = 30
        self.gaz_emb_dim = 50
        self.gaz_dropout = 0.5
        self.pretrain_word_embedding = None
        self.pretrain_biword_embedding = None
        self.pretrain_gaz_embedding = None
        self.label_size = 0
        self.word_alphabet_size = 0
        self.biword_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0
        ### hyperparameters
        self.HP_iteration = 100
        self.HP_batch_size = 10
        self.HP_char_hidden_dim = 50
        self.HP_hidden_dim = 200
        self.HP_dropout = 0.5
        self.HP_lstm_layer = 1
        self.HP_bilstm = True
        self.HP_use_char = False
        self.HP_gpu = False
        self.HP_lr = 0.015
        self.HP_lr_decay = 0.05
        self.HP_clip = 5.0
        self.HP_momentum = 0

    def show_data_summary(self):
        addLogSectionMark("DATA SUMMARY")
        print("DATA SUMMARY START:")
        print("     Tag          scheme: %s" % (self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s" % (self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s" % (self.number_normalized))
        # print("     Punctuation  filter: %s" % (self.punctuation_filter))
        print("     Use          bigram: %s" % (self.use_bigram))
        print("     Word  alphabet size: %s" % (self.word_alphabet_size))
        print("     Biword alphabet size: %s" % (self.biword_alphabet_size))
        print("     Char  alphabet size: %s" % (self.char_alphabet_size))
        print("     Gaz   alphabet size: %s" % (self.gaz_alphabet.size()))
        print("     Label alphabet size: %s" % (self.label_alphabet_size))
        print("     Word embedding size: %s" % (self.word_emb_dim))
        print("     Biword embedding size: %s" % (self.biword_emb_dim))
        print("     Char embedding size: %s" % (self.char_emb_dim))
        print("     Gaz embedding size: %s" % (self.gaz_emb_dim))
        print("     Norm     word   emb: %s" % (self.norm_word_emb))
        print("     Norm     biword emb: %s" % (self.norm_biword_emb))
        print("     Norm     gaz    emb: %s" % (self.norm_gaz_emb))
        print("     Norm   gaz  dropout: %s" % (self.gaz_dropout))
        print("     Train instance number: %s" % (len(self.train_texts)))
        print("     Dev   instance number: %s" % (len(self.dev_texts)))
        print("     Test  instance number: %s" % (len(self.test_texts)))
        print("     Raw   instance number: %s" % (len(self.raw_texts)))
        print("     Hyperpara  iteration: %s" % (self.HP_iteration))
        print("     Hyperpara  batch size: %s" % (self.HP_batch_size))
        print("     Hyperpara          lr: %s" % (self.HP_lr))
        print("     Hyperpara    lr_decay: %s" % (self.HP_lr_decay))
        print("     Hyperpara     HP_clip: %s" % (self.HP_clip))
        print("     Hyperpara    momentum: %s" % (self.HP_momentum))
        print("     Hyperpara  hidden_dim: %s" % (self.HP_hidden_dim))
        print("     Hyperpara     dropout: %s" % (self.HP_dropout))
        print("     Hyperpara  lstm_layer: %s" % (self.HP_lstm_layer))
        print("     Hyperpara      bilstm: %s" % (self.HP_bilstm))
        print("     Hyperpara         GPU: %s" % (self.HP_gpu))
        print("     Hyperpara     use_gaz: %s" % (self.HP_use_gaz))
        print("     Hyperpara fix gaz emb: %s" % (self.HP_fix_gaz_emb))
        print("     Hyperpara    use_char: %s" % (self.HP_use_char))

        logger.info("     Tag          scheme: %s" % (self.tagScheme))
        logger.info("     MAX SENTENCE LENGTH: %s" %
                    (self.MAX_SENTENCE_LENGTH))
        logger.info("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        logger.info("     Number   normalized: %s" % (self.number_normalized))
        logger.info("     Use          bigram: %s" % (self.use_bigram))
        logger.info("     Word  alphabet size: %s" % (self.word_alphabet_size))
        logger.info("     Biword alphabet size: %s" %
                    (self.biword_alphabet_size))
        logger.info("     Char  alphabet size: %s" % (self.char_alphabet_size))
        logger.info("     Gaz   alphabet size: %s" %
                    (self.gaz_alphabet.size()))
        logger.info("     Label alphabet size: %s" %
                    (self.label_alphabet_size))
        logger.info("     Word embedding size: %s" % (self.word_emb_dim))
        logger.info("     Biword embedding size: %s" % (self.biword_emb_dim))
        logger.info("     Char embedding size: %s" % (self.char_emb_dim))
        logger.info("     Gaz embedding size: %s" % (self.gaz_emb_dim))
        logger.info("     Norm     word   emb: %s" % (self.norm_word_emb))
        logger.info("     Norm     biword emb: %s" % (self.norm_biword_emb))
        logger.info("     Norm     gaz    emb: %s" % (self.norm_gaz_emb))
        logger.info("     Norm   gaz  dropout: %s" % (self.gaz_dropout))
        logger.info("     Train instance number: %s" % (len(self.train_texts)))
        logger.info("     Dev   instance number: %s" % (len(self.dev_texts)))
        logger.info("     Test  instance number: %s" % (len(self.test_texts)))
        logger.info("     Raw   instance number: %s" % (len(self.raw_texts)))
        logger.info("     Hyperpara  iteration: %s" % (self.HP_iteration))
        logger.info("     Hyperpara  batch size: %s" % (self.HP_batch_size))
        logger.info("     Hyperpara          lr: %s" % (self.HP_lr))
        logger.info("     Hyperpara    lr_decay: %s" % (self.HP_lr_decay))
        logger.info("     Hyperpara     HP_clip: %s" % (self.HP_clip))
        logger.info("     Hyperpara    momentum: %s" % (self.HP_momentum))
        logger.info("     Hyperpara  hidden_dim: %s" % (self.HP_hidden_dim))
        logger.info("     Hyperpara     dropout: %s" % (self.HP_dropout))
        logger.info("     Hyperpara  lstm_layer: %s" % (self.HP_lstm_layer))
        logger.info("     Hyperpara      bilstm: %s" % (self.HP_bilstm))
        logger.info("     Hyperpara         GPU: %s" % (self.HP_gpu))
        logger.info("     Hyperpara     use_gaz: %s" % (self.HP_use_gaz))
        logger.info("     Hyperpara fix gaz emb: %s" % (self.HP_fix_gaz_emb))
        print("     Hyperpara    use_char: %s" % (self.HP_use_char))
        if self.HP_use_char:
            print("             Char_features: %s" % (self.char_features))
            logger.info("             Char_features: %s" %
                        (self.char_features))
        print("DATA SUMMARY END.")
        sys.stdout.flush()

    def refresh_label_alphabet(self, input_file):
        old_size = self.label_alphabet_size
        self.label_alphabet.clear(True)
        in_lines = open(input_file, 'r').readlines()
        for line in in_lines:
            if len(line) > 2:
                pairs = line.strip().split()
                label = pairs[-1]
                self.label_alphabet.add(label)
        self.label_alphabet_size = self.label_alphabet.size()
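        # Infer the tag scheme from the label inventory: any "B-" label means
        # the data is chunked; if "S-" labels (single-token entities) also
        # occur, assume BMES, otherwise fall back to BIO. For example,
        # {O, B-PER, I-PER} -> BIO; {O, B-PER, M-PER, E-PER, S-PER} -> BMES.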
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"
        self.fix_alphabet()
        print("Refresh label alphabet finished: old:%s -> new:%s" %
              (old_size, self.label_alphabet_size))

    def build_alphabet(self, input_file):
        in_lines = open(input_file, 'r').readlines()
        for idx in xrange(len(in_lines)):
            line = in_lines[idx]
            if len(line) > 2:
                pairs = line.strip().split()
                word = pairs[0].decode('utf-8')
                if self.number_normalized:
                    word = normalize_word(word)
                label = pairs[-1]
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)

                if idx < len(in_lines) - 1 and len(in_lines[idx + 1]) > 2:
                    biword = word + in_lines[
                        idx + 1].strip().split()[0].decode('utf-8')
                else:
                    biword = word + NULLKEY

                self.biword_alphabet.add(biword)
                for char in word:
                    self.char_alphabet.add(char)
        self.word_alphabet_size = self.word_alphabet.size()
        self.biword_alphabet_size = self.biword_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"

    def build_gaz_file(self, gaz_file):
        ## Build the gazetteer; entries are initially read from the gaz embedding file.
        if gaz_file:
            fins = open(gaz_file, 'r').readlines()
            for fin in fins:
                fin = fin.strip().split()[0].decode('utf-8')
                if fin:
                    self.gaz.insert(fin, "one_source")
            print "Load gaz file: ", gaz_file, " total size:", self.gaz.size()
        else:
            print "Gaz file is None, load nothing"

    def build_gaz_alphabet(self, input_file):
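        # Scan the corpus sentence by sentence (blank lines are separators);
        # at every character position, each lexicon entry that
        # Gazetteer.enumerateMatchList reports as starting there is added to
        # gaz_alphabet.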
        in_lines = open(input_file, 'r').readlines()
        word_list = []
        for line in in_lines:
            if len(line) > 3:
                word = line.split()[0].decode('utf-8')
                if self.number_normalized:
                    word = normalize_word(word)
                word_list.append(word)
            else:
                w_length = len(word_list)
                for idx in range(w_length):
                    matched_entity = self.gaz.enumerateMatchList(
                        word_list[idx:])
                    for entity in matched_entity:
                        # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity)
                        self.gaz_alphabet.add(entity)
                word_list = []
        print "gaz alphabet size:", self.gaz_alphabet.size()

    def fix_alphabet(self):
        self.word_alphabet.close()
        self.biword_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close()
        self.gaz_alphabet.close()

    def build_word_pretrain_emb(self, emb_path):
        print "build word pretrain emb..."
        self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
            emb_path, self.word_alphabet, self.word_emb_dim,
            self.norm_word_emb)

    def build_biword_pretrain_emb(self, emb_path):
        print "build biword pretrain emb..."
        self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(
            emb_path, self.biword_alphabet, self.biword_emb_dim,
            self.norm_biword_emb)

    def build_gaz_pretrain_emb(self, emb_path):
        print "build gaz pretrain emb..."
        self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(
            emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb)

    def generate_instance(self, input_file, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_seg_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_seg_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_seg_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_seg_instance(
                input_file, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        else:
            print(
                "Error: you can only generate train/dev/test/raw instance! Illegal input:%s"
                % (name))

    def generate_instance_with_gaz(self, input_file, name):
        self.fix_alphabet()
        if name == "train":
            self.train_texts, self.train_Ids = read_instance_with_gaz(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance_with_gaz(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance_with_gaz(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance_with_gaz(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "sentence":
            self.raw_texts, self.raw_Ids = read_instance_with_gaz_text(
                input_file, self.gaz, self.word_alphabet, self.biword_alphabet,
                self.char_alphabet, self.gaz_alphabet, self.label_alphabet,
                self.number_normalized, self.MAX_SENTENCE_LENGTH)
        else:
            print(
                "Error: you can only generate train/dev/test/raw/sentence instance! Illegal input:%s"
                % (name))

    def write_decoded_results(self, output_file, predict_results, name):
        fout = open(output_file, 'w')
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be one of train/dev/test/raw!"
            )
        assert (sent_num == len(content_list))
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                ## content_list[idx] is a list with [word, char, label]
                fout.write(content_list[idx][0][idy].encode('utf-8') + " " +
                           predict_results[idx][idy] + '\n')

            fout.write('\n')
        fout.close()
        print("Predict %s result has been written into file. %s" %
              (name, output_file))

    def write_decoded_results_back(self, predict_results, name):
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print(
                "Error: illegal name during writing predict result, name should be one of train/dev/test/raw!"
            )

        assert (sent_num == len(content_list))
        result = []
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])
            for idy in range(sent_length):
                ## content_list[idx] is a list with [word, char, label]
                print(content_list[idx][0][idy].encode('utf-8') + " " +
                      predict_results[idx][idy] + '\n')

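        # Decode the predicted tag sequence into entity spans: "S-" emits a
        # single-token entity, "B-" opens a span, "I-" accumulates its value,
        # and "E-" closes the span and appends it to the result list.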
        for idx in range(sent_num):
            sent_length = len(predict_results[idx])

            data = {'start': '', 'end': "", 'value': '', 'entity': ''}
            value = ''
            for idy in range(sent_length):
                pre_su_item = predict_results[idx][idy].split('-')
                if pre_su_item[0] == 'S':
                    data['start'] = str(idy)
                    data['end'] = str(idy + 1)
                    data['value'] = content_list[idx][0][idy].encode('utf-8')
                    data['entity'] = pre_su_item[1]
                    result.append(data)
                    data = {'start': '', 'end': "", 'value': '', 'entity': ''}
                if pre_su_item[0] == 'B':
                    data['start'] = str(idy)
                    value = value + (content_list[idx][0][idy].encode('utf-8'))
                if pre_su_item[0] == 'E':
                    value = value + (content_list[idx][0][idy].encode('utf-8'))
                    data['end'] = str(idy + 1)
                    data['value'] = value
                    data['entity'] = pre_su_item[1]
                    result.append(data)
                    data = {'start': '', 'end': "", 'value': '', 'entity': ''}
                    value = ''
                if pre_su_item[0] == 'I':
                    value = value + (content_list[idx][0][idy].encode('utf-8'))

        return result

    def write_http_data(self, output_file, inputData, name):
        fout = open(output_file, 'w')
        get_num = len(inputData)

        start = 0
        numOfParagram = int(math.ceil(get_num / 5.0))
        num_start_sentence = start
        num_end_sentence = numOfParagram

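        # Slice the input into fifths: the first fifth becomes the test
        # portion, the second fifth dev, and the remaining three fifths train.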
        if name == "test":
            num_start_sentence = 0
            num_end_sentence = numOfParagram
        elif name == "dev":
            num_start_sentence = numOfParagram
            num_end_sentence = numOfParagram * 2
        elif name == "train":
            num_start_sentence = numOfParagram * 2
            num_end_sentence = get_num

        for idx in range(num_start_sentence, num_end_sentence):
            text = inputData[idx]["text"]
            entities = inputData[idx]["entities"]

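            # Convert character-offset entity annotations into per-character
            # tags: single-character entities get "S-", longer spans open with
            # "B-" and close with "E-", characters inside an open span get
            # "I-", and everything else is "O" (idText is 1-based; entity
            # 'start' is 0-based inclusive, 'end' 0-based exclusive).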
            idText = 1
            inWord = False
            tagReady = False
            entity_name = ''
            for Text in text:
                ## content_list[idx] is a list with [word, char, label]
                tagReady = False

                for entity in entities:
                    if not inWord:
                        if entity['start'] + 1 == entity['end'] and entity[
                                'end'] == idText:
                            fout.write(
                                Text.encode('utf-8') + " " + "S-" +
                                entity['entity'].encode('utf-8') + '\n')
                            tagReady = True
                            break
                        if entity['start'] + 1 == idText:
                            fout.write(
                                Text.encode('utf-8') + " " + "B-" +
                                entity['entity'].encode('utf-8') + '\n')
                            tagReady = True
                            inWord = True
                            entity_name = entity['entity'].encode('utf-8')
                            break
                    else:
                        if entity['end'] == idText:
                            fout.write(
                                Text.encode('utf-8') + " " + "E-" +
                                entity_name + '\n')
                            tagReady = True
                            inWord = False
                            break

                if not tagReady:
                    if not inWord:
                        fout.write(Text.encode('utf-8') + " " + "O" + '\n')
                    else:
                        fout.write(
                            Text.encode('utf-8') + " " + "I-" + entity_name +
                            '\n')

                idText = idText + 1
            fout.write('\n')
        fout.close()

        print("Predict input data has been written into file. %s" %
              (output_file))
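
A minimal usage sketch for the gazetteer-augmented Data class above, assuming the snippet's external helpers (Alphabet, Gazetteer, build_pretrain_embedding, read_instance_with_gaz) are importable and using hypothetical file paths:

data = Data()
# Grow the word/biword/char/label alphabets from every split first.
data.build_alphabet("train.char.bmes")
data.build_alphabet("dev.char.bmes")
data.build_alphabet("test.char.bmes")
# Load the lexicon, then register every entry matched in the corpora.
data.build_gaz_file("lexicon.txt")
data.build_gaz_alphabet("train.char.bmes")
data.build_gaz_alphabet("dev.char.bmes")
data.build_gaz_alphabet("test.char.bmes")
data.fix_alphabet()  # freeze all alphabets
# Convert each split into id sequences with gazetteer matches attached.
data.generate_instance_with_gaz("train.char.bmes", "train")
data.generate_instance_with_gaz("dev.char.bmes", "dev")
data.generate_instance_with_gaz("test.char.bmes", "test")
data.show_data_summary()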
Code Example #7
class PairGenerator(object):
    '''Generate minibatches with
    real-time data combination.
    '''
    def __init__(self, item_path, sub_item_path, pair_path, split_c=','):
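        # The locals() idiom below stashes every constructor argument as an
        # instance attribute (it also stores a self-reference, a known quirk
        # of this shortcut).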
        self.__dict__.update(locals())

        print('Loading title and category information...', time.ctime())
        sub_item_set = set()
        for line in open(sub_item_path).readlines():
            sub_item_set.add(line.split()[0])
        self.item_title = {}
        self.item_cat = {}
        self.cat2idx = {}
        self.max_len = 0
        sentence_list = []
        for line in open(item_path).readlines():
            tmp = line.split()
            item = tmp[0]
            cat = tmp[1]
            if cat not in self.cat2idx:
                self.cat2idx[cat] = len(self.cat2idx)
            title = tmp[2].split(split_c)
            self.item_title[item] = title
            self.item_cat[item] = self.cat2idx[cat]
            if item in sub_item_set:
                sentence_list.append(title)
                self.max_len = min(config.max_len, max(self.max_len,
                                                       len(title)))
        print(('%s items' % len(sentence_list)), time.ctime())

        print('Generating alphabet...', time.ctime())
        self.alphabet = Alphabet()
        add_to_vocab(sentence_list, self.alphabet)
        print(('%s words' % len(self.alphabet)), time.ctime())

        print('Generating weight from word2vec model...', time.ctime())
        self.sentence_list = sentence_list
        w2v_model = word2vec(sentence_list)
        self.w2v_weight = np.zeros((len(self.alphabet), config.w2vSize))
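        # Copy vectors for in-vocabulary words; words the word2vec model has
        # never seen keep their all-zero rows.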
        for word, idx in self.alphabet.iteritems():
            if word in w2v_model.vocab:
                self.w2v_weight[idx] = w2v_model[word]

        print('Loading pairs ...', time.ctime())
        self.pair_list = open(pair_path).readlines()

    def batch(self, pair_list):
        left_in = []
        right_in = []
        left_out = []
        right_out = []
        pair_out = []
        for line in pair_list:
            tmp = line.split(',')
            left = tmp[0]
            right = tmp[1]
            pair = int(tmp[2])
            left_in.append(
                [self.alphabet[word] for word in self.item_title[left]])
            right_in.append(
                [self.alphabet[word] for word in self.item_title[right]])
            #left_out.append(self.item_cat[left])
            #right_out.append(self.item_cat[right])
            pair_out.append(pair)

        return {
            'left_in': sequence.pad_sequences(left_in, maxlen=self.max_len),
            'right_in': sequence.pad_sequences(right_in, maxlen=self.max_len),
            #'left_out': np_utils.to_categorical(left_out, nb_classes=len(self.cat2idx)),
            #'right_out': np_utils.to_categorical(right_out, nb_classes=len(self.cat2idx)),
            'pair_out': np.array(pair_out)
        }

    def fetch_all(self, val_split=0.2, pair_list=None):
        if not pair_list:
            pair_list = self.pair_list
        upper_bound = int(len(pair_list) * (1 - val_split))
        train_data = self.batch(pair_list[0:upper_bound])
        val_data = self.batch(pair_list[upper_bound:])
        return (train_data, val_data)

    def batch_topic(self, item_vector_dic, pair_list):
        pair_in = []
        pair_out = []
        for line in pair_list:
            tmp = line.split(',')
            left = tmp[0]
            right = tmp[1]
            pair = int(tmp[2])
            vector_left = [0.0] * config.w2vSize
            vector_right = [0.0] * config.w2vSize
            for (pos, value) in item_vector_dic[left]:
                vector_left[pos] = value
            for (pos, value) in item_vector_dic[right]:
                vector_right[pos] = value
            pair_in.append(vector_left + vector_right)
            pair_out.append(pair)

        return {'pair_in': pair_in, 'pair_out': pair_out}

    def fetch_all_topic(self, mark, val_split=0.2, pair_list=None):
        item_vector_dic = {}
        if config.fresh:
            print('Generating topic model...', time.ctime())
            (model, dictionary) = lda(self.sentence_list)

            print('Generating topic vector...', time.ctime())
            for item in self.item_title:
                title = self.item_title[item]
                doc_bow = dictionary.doc2bow(title)
                item_vector_dic[item] = model[doc_bow]
            with open(mark + 'topic.tmp', 'w') as f:
                pickle.dump(item_vector_dic, f)
        else:
            print('Loading topic vector...', time.ctime())
            with open(mark + 'topic.tmp') as f:
                item_vector_dic = pickle.load(f)

        if not pair_list:
            pair_list = self.pair_list
        upper_bound = int(len(pair_list) * (1 - val_split))
        train_data = self.batch_topic(item_vector_dic,
                                      pair_list[0:upper_bound])
        val_data = self.batch_topic(item_vector_dic, pair_list[upper_bound:])
        return (train_data, val_data)

    def flow(self, batch_size=32, shuffle=False, seed=None):
        if seed:
            random.seed(seed)

        if shuffle:
            random.shuffle(self.pair_list)

        b = 0
        pair_num = len(self.pair_list)
        while True:
            current_index = (b * batch_size) % pair_num
            if pair_num >= current_index + batch_size:
                current_batch_size = batch_size
            else:
                current_batch_size = pair_num - current_index
            yield self.batch(self.pair_list[current_index:current_index +
                                            current_batch_size])
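
A short usage sketch for PairGenerator, assuming the item, sub-item, and pair files follow the whitespace- and comma-separated layouts the constructor parses (all paths hypothetical):

gen = PairGenerator('items.txt', 'sub_items.txt', 'pairs.csv')

# One-shot: padded id sequences split 80/20 into train/validation dicts.
train_data, val_data = gen.fetch_all(val_split=0.2)

# Streaming: an endless generator of minibatch dicts with keys
# 'left_in', 'right_in' and 'pair_out', e.g. for fit_generator-style loops.
for batch in gen.flow(batch_size=32, shuffle=True, seed=7):
    print(batch['pair_out'].shape)
    break  # inspect a single batch only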
Code Example #8
class Data:
    def __init__(self):
        self.MAX_SENTENCE_LENGTH = 512
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = False
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('character')
        self.word_alphabet.add(START)
        self.word_alphabet.add(UNKNOWN)
        self.char_alphabet.add(START)
        self.char_alphabet.add(UNKNOWN)
        self.char_alphabet.add(PADDING)
        self.label_alphabet = Alphabet('label')
        self.tagScheme = "NoSeg"

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []

        self.word_emb_dim = 50
        self.pretrain_word_embedding = None
        self.label_size = 0
        self.word_alphabet_size = 0
        self.char_alphabet_size = 0
        self.label_alphabet_size = 0
        ### hyperparameters
        self.HP_batch_size = 10
        self.HP_hidden_dim = 200
        self.HP_dropout = 0.5
        self.HP_lstm_layer = 1
        self.HP_bilstm = True
        self.HP_use_char = True
        self.HP_gpu = False
        self.HP_lr = 0.015
        self.HP_lr_decay = 0
        self.HP_clip = 5.0
        self.HP_momentum = 0

    def show_data_summary(self):
        print("DATA SUMMARY START:")
        print("     Tag          scheme: %s" % (self.tagScheme))
        print("     MAX SENTENCE LENGTH: %s" % (self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s" % (self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s" % (self.number_normalized))
        print("     Word  alphabet size: %s" % (self.word_alphabet_size))
        print("     Char  alphabet size: %s" % (self.char_alphabet_size))
        print("     Label alphabet size: %s" % (self.label_alphabet_size))
        print("     Word embedding size: %s" % (self.word_emb_dim))
        print("     Train instance number: %s" % (len(self.train_texts)))
        print("     Dev   instance number: %s" % (len(self.dev_texts)))
        print("     Test  instance number: %s" % (len(self.test_texts)))
        print("     Raw   instance number: %s" % (len(self.raw_texts)))
        print("     Hyperpara  batch size: %s" % (self.HP_batch_size))
        print("     Hyperpara          lr: %s" % (self.HP_lr))
        print("     Hyperpara    lr_decay: %s" % (self.HP_lr_decay))
        print("     Hyperpara     HP_clip: %s" % (self.HP_clip))
        print("     Hyperpara    momentum: %s" % (self.HP_momentum))
        print("     Hyperpara  hidden_dim: %s" % (self.HP_hidden_dim))
        print("     Hyperpara     dropout: %s" % (self.HP_dropout))
        print("     Hyperpara  lstm_layer: %s" % (self.HP_lstm_layer))
        print("     Hyperpara      bilstm: %s" % (self.HP_bilstm))
        print("     Hyperpara    use_char: %s" % (self.HP_use_char))
        print("     Hyperpara         GPU: %s" % (self.HP_gpu))
        print("DATA SUMMARY END.")
        sys.stdout.flush()

    def build_alphabet(self, input_file):
        in_lines = open(input_file, 'r').readlines()
        for line in in_lines:
            if len(line) > 2:
                pairs = line.strip().split()
                word = pairs[0]
                if self.number_normalized:
                    word = normalize_word(word)
                label = pairs[-1]
                self.label_alphabet.add(label)
                self.word_alphabet.add(word)
                for char in word:
                    self.char_alphabet.add(char)
        self.word_alphabet_size = self.word_alphabet.size()
        self.char_alphabet_size = self.char_alphabet.size()
        self.label_alphabet_size = self.label_alphabet.size()
        startS = False
        startB = False
        for label, _ in self.label_alphabet.iteritems():
            if "S-" in label.upper():
                startS = True
            elif "B-" in label.upper():
                startB = True
        if startB:
            if startS:
                self.tagScheme = "BMES"
            else:
                self.tagScheme = "BIO"

    def fix_alphabet(self):
        self.word_alphabet.close()
        self.char_alphabet.close()
        self.label_alphabet.close()

    def build_word_pretrain_emb(self, emb_path, norm=False):
        self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(
            emb_path, self.word_alphabet, self.word_emb_dim, norm)

    def generate_instance(self, input_file, name):
        self.fix_alphabet()
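        # Note: this variant passes MAX_WORD_LENGTH (default -1, i.e. no cap)
        # to read_instance, whereas the earlier Data classes pass
        # MAX_SENTENCE_LENGTH to their instance readers.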
        if name == "train":
            self.train_texts, self.train_Ids = read_instance(
                input_file, self.word_alphabet, self.char_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_WORD_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance(
                input_file, self.word_alphabet, self.char_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_WORD_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance(
                input_file, self.word_alphabet, self.char_alphabet,
                self.label_alphabet, self.number_normalized,
                self.MAX_WORD_LENGTH)
        else:
            print(
                "Error: you can only generate train/dev/test instance! Illegal input:%s"
                % (name))
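
A matching sketch for this simpler Data class, again with hypothetical file paths:

data = Data()
data.build_alphabet("train.bmes")  # grow alphabets from every split
data.build_alphabet("dev.bmes")
data.build_alphabet("test.bmes")
data.build_word_pretrain_emb("word_emb.50d.txt")  # optional
data.generate_instance("train.bmes", "train")  # freezes the alphabets first
data.generate_instance("dev.bmes", "dev")
data.generate_instance("test.bmes", "test")
data.show_data_summary()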