Example #1
 def pack(self, values_str, separator=',', has_tag_name=True):
     #pack header
     header=(self.message_type, self.body_struct.size)
     buff=self.header_struct.pack(*header)
     
     #pack body
     value_list=values_str.split(separator)
     if len(value_list)!=self.field_cnt:
         #TODO:return error
         LogUtil.error("Error Format:"+str(self.message_type)+":"+str(len(value_list))+":"+str(self.field_cnt))
         #return
     target_value_list=[]    
     for field_index in range(0, self.field_cnt):
         if has_tag_name:
             field_value=value_list[field_index].split('=')[1]
         else:
             field_value=value_list[field_index]
         if self.field_list[field_index][2]=='C':
             target_value=str.encode(field_value.ljust(self.field_list[field_index][3]))
         elif self.field_list[field_index][2]=='N':
             target_value=int(field_value)
         else:    
             target_value=str.encode(field_value.ljust(self.field_list[field_index][3]))
         target_value_list.append(target_value)    
         
     buff=buff+self.body_struct.pack(*target_value_list)
     check_sum = calculate_check_sum(buff, self.header_struct.size+self.body_struct.size)
     buff=buff+self.footer_struct.pack(check_sum)
     
     return buff
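
A hedged usage sketch for pack (not taken from the source): the module alias msg, the DB file name, the message type 100101, and the two tag=value fields are illustrative assumptions modeled on the test_unpack example further down.

import binascii
import message_processor as msg   # hypothetical module name for the classes shown in these examples

factory = msg.MessageProcessorFactory()
factory.load_from_db('message_config.s3db', [])       # same DB file name as in test_unpack below
processor = factory.build_message_processor(100101)   # made-up message type with two fields
buff = processor.pack("ClOrdID=ABC123,OrdStatus=0")   # tag=value pairs joined by the default separator
print(binascii.hexlify(buff).decode())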
Example #2
 def load_from_db(self, db_file, message_type_list):
     conn=sqlite3.connect(db_file)
     curs=conn.cursor()
     if len(message_type_list)==0:
         curs.execute('select distinct message_type from message_body_def order by message_type')
         for row in curs.fetchall():
             message_type_list.append(row[0])
     query='''select t.message_type,t.field_name,t.field_desc,t.format_string,t.type_category,t.type_len,t.ref_field,t.field_order,t.struct_tag 
             from vw_message_body_def t
             where t.message_type=?
             order by t.field_order
             '''
     for message_type in message_type_list:
         field_list=[]
         field_cnt=0
         curs.execute(query, [message_type])
         for row in curs.fetchall():
             field_cnt=field_cnt+1
             LogUtil.info(row)
             if row[7]!=field_cnt-1:
                 #TODO:raiseExceptions
                 LogUtil.error("field list disorder:mesType="+str(message_type)+" field:"+str(field_cnt))
             field_list.append((row[1], row[3], row[4], row[5], row[6], row[8]))
         message_processor=MessageProcessor(message_type, field_list)
         self.message_processors[message_type]=message_processor
Example #3
    def encrypt(self, content, secret, crypto_type):
        '''
        Encrypt
        :param content: content to encrypt
        :param secret: secret key
        :param crypto_type: encryption type
        :return:
        '''
        if crypto_type == "AES":
            BLOCK_SIZE = 16
            pad = lambda s: s + (BLOCK_SIZE - len(s) % BLOCK_SIZE) * chr(
                BLOCK_SIZE - len(s) % BLOCK_SIZE)
            try:
                obj = AES.new(secret, AES.MODE_ECB)
                crypt = obj.encrypt(pad(content))
                return base64.b64encode(crypt).decode('utf-8')
            except:
                LogUtil().error("AES加密失败!----" + traceback.format_exc())
        elif crypto_type == "3DES":
            BLOCK_SIZE = 8
            pad = lambda s: s + (BLOCK_SIZE - len(s) % BLOCK_SIZE) * chr(
                BLOCK_SIZE - len(s) % BLOCK_SIZE)

            try:
                encryptor = DES3.new(secret, DES3.MODE_ECB)
                crypt = encryptor.encrypt(pad(content))
                return base64.b64encode(crypt).decode('utf-8')
            except:
                LogUtil().error("3DES加密失败!----" + traceback.format_exc())
Example #4
def _get_lr_scheduler(args, kv):
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_factor = args.config.getfloat('train', 'lr_factor')
    if lr_factor >= 1:
        return (learning_rate, None)
    epoch_size = args.num_examples / args.batch_size
    if 'dist' in args.kv_store:
        epoch_size /= kv.num_workers
    mode = args.config.get('common', 'mode')
    begin_epoch = 0
    if mode == "load":
        model_file = args.config.get('common', 'model_file')
        begin_epoch = int(
            model_file.split("-")[1]) if len(model_file) == 16 else int(
                model_file.split("n_epoch")[1].split("n_batch")[0])
    step_epochs = [
        int(l) for l in args.config.get('train', 'lr_step_epochs').split(',')
    ]
    for s in step_epochs:
        if begin_epoch >= s:
            learning_rate *= lr_factor
    if learning_rate != args.config.getfloat('train', 'learning_rate'):
        log = LogUtil().getlogger()
        log.info('Adjust learning rate to %e for epoch %d' %
                 (learning_rate, begin_epoch))

    steps = [
        epoch_size * (x - begin_epoch) for x in step_epochs
        if x - begin_epoch > 0
    ]
    return (learning_rate,
            mx.lr_scheduler.MultiFactorScheduler(step=steps,
                                                 factor=args.lr_factor))
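
A small worked example of the step arithmetic above with made-up numbers (epoch_size 1000, lr_step_epochs '10,20,30', begin_epoch 15, lr_factor 0.5): steps already passed shrink the learning rate and drop out of the step list.

epoch_size, begin_epoch = 1000, 15
step_epochs = [10, 20, 30]
learning_rate, lr_factor = 0.01, 0.5

for s in step_epochs:
    if begin_epoch >= s:
        learning_rate *= lr_factor     # only the 10-epoch step has already passed: 0.01 -> 0.005

steps = [epoch_size * (x - begin_epoch) for x in step_epochs if x - begin_epoch > 0]
print(learning_rate, steps)            # 0.005 [5000, 15000]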
Example #5
 def replyFinished(self, reply):
     if (reply.error()):
         LogUtil.error(reply.errorString())
     else:
         file = QFile(self.fileName)
         if file.open(QIODevice.WriteOnly):
             file.write(reply.readAll())
             file.flush()
             file.close()
     reply.deleteLater()
     reply.close()
Example #6
    def load_metadata_from_desc_file(
        self,
        desc_file,
        partition='train',
        max_duration=16.0,
    ):
        """ Read metadata from the description file
            (possibly takes long, depending on the filesize)
        Params:
            desc_file (str):  Path to a JSON-line file that contains labels and
                paths to the audio files
            partition (str): One of 'train', 'validation' or 'test'
            max_duration (float): In seconds, the maximum duration of
                utterances to train or test on
        """
        logger = LogUtil().getlogger()
        logger.info('Reading description file: {} for partition: {}'.format(
            desc_file, partition))
        audio_paths, durations, texts = [], [], []
        with open(desc_file, 'rt', encoding='UTF-8') as json_line_file:
            for line_num, json_line in enumerate(json_line_file):
                try:
                    spec = json.loads(json_line)
                    if float(spec['duration']) > max_duration:
                        continue
                    audio_paths.append(spec['key'])
                    durations.append(float(spec['duration']))
                    texts.append(spec['text'])
                except Exception as e:
                    # Change to (KeyError, ValueError) or
                    # (KeyError,json.decoder.JSONDecodeError), depending on
                    # json module version
                    logger.warn(str(e))
                    logger.warn('Error reading line num #{}'.format(line_num))
                    logger.warn('line {}'.format(json_line))  # json_line is already str in text mode

        if partition == 'train':
            self.count = len(audio_paths)
            self.train_audio_paths = audio_paths
            self.train_durations = durations
            self.train_texts = texts
        elif partition == 'validation':
            self.val_audio_paths = audio_paths
            self.val_durations = durations
            self.val_texts = texts
            self.val_count = len(audio_paths)
        elif partition == 'test':
            self.test_audio_paths = audio_paths
            self.test_durations = durations
            self.test_texts = texts
        else:
            raise Exception("Invalid partition to load metadata. "
                            "Must be train/validation/test")
Example #7
def md_requestor(name, status, req_queue, config_file):
    LogUtil.get_instance(config_file, "db_reader")
    LogUtil.info("db_reader:"+name+" begin")

    config = configparser.ConfigParser()
    config.read(config_file)

    factory = msg.MessageProcessorFactory()
    db_file = config.get("message_config", "db_file")
    factory.load_from_db(db_file, [])

    trade_db_file = config.get("reqresp", "db_file")
    read_interval = config.getfloat("reqresp", "read_interval")
    if read_interval is None:
        read_interval = 0.5

    last_req_num = 0
    conn = sqlite3.connect(trade_db_file)
    query_curs = conn.cursor()
    update_curs = conn.cursor()
    query_unreported = """SELECT reqnum,message_type,appid,oms_order_id,order_status,req_text
        FROM req_resp
        WHERE order_status='0'
        AND reqnum>?
        ORDER BY reqnum
    """
    update_reported = """UPDATE req_resp
    SET order_status=?,report_time=strftime('%Y-%m-%d %H:%M:%f','now')
    WHERE reqnum in(?)
    """
    last_read_cnt = 1
    while status.value == 0:
        if last_read_cnt == 0:
            time.sleep(read_interval)

        last_read_cnt = 0
        query_curs.execute(query_unreported, [last_req_num])
        for (reqnum, message_type, appid, message_id, order_status, req_text) in query_curs.fetchall():
            message_processor = factory.build_message_processor(message_type)
            send_buff = message_processor.pack(req_text)
            req_queue.put(send_buff)
            last_req_num = reqnum
            update_curs.execute(update_reported, ['2', reqnum])
            last_read_cnt = last_read_cnt+1
            if send_buff:
                LogUtil.debug("db_reader putQ:"+binascii.hexlify(send_buff).decode())
            else:
                LogUtil.debug("db_reader putQ: send_buff NULL")
        conn.commit()

    LogUtil.info("db_reader:"+name+" end")
Example #8
def db_reader(name, status, req_queue, config_file):
    LogUtil.get_instance(config_file, "db_reader")
    LogUtil.info("db_reader:"+name+" begin")

    config=configparser.ConfigParser()
    config.read(config_file)
    
    factory=msg.MessageProcessorFactory()
    db_file=config.get("message_config", "db_file")
    factory.load_from_db(db_file, [])
    
    read_interval=config.getfloat("reqresp","read_interval")
    if read_interval is None:
        read_interval=0.5
    
    host=config.get("reqresp", "host")
    database=config.get("reqresp", "database")
    user=config.get("reqresp","user")
    password=config.get("reqresp", "password")
    
    last_req_num=0
    conn=pgdb.connect(database=database, host=host, user=user, password=password)
    query_curs=conn.cursor()
    update_curs=conn.cursor()
    query_unreported="""SELECT req_num,message_type,appid,oms_order_id,rept_status,req_text
        FROM req_resp
        WHERE rept_status='0'
        AND req_num>%(req_num)s
        ORDER BY req_num
    """
    query_dict={'req_num':0}
    
    update_reported="""UPDATE req_resp 
    SET rept_status=%(rept_status)s,report_time=localtimestamp
    WHERE req_num in (%(req_num)s)    
    """
    update_dict={'rept_status':'0', 'req_num':0}
    
    last_read_cnt=1
    while status.value==0:
        if last_read_cnt==0:
            time.sleep(read_interval)

        last_read_cnt=0
        query_dict['req_num']=last_req_num
        query_curs.execute(query_unreported,query_dict)
        for (req_num,message_type,appid,message_id,rept_status,req_text) in query_curs.fetchall():
            message_processor=factory.build_message_processor(message_type)
            send_buff=message_processor.pack(req_text)
            req_queue.put(send_buff)
            last_req_num=req_num
            update_dict['rept_status']='2'
            update_dict['req_num']=last_req_num
            
            update_curs.execute(update_reported,update_dict)
            last_read_cnt=last_read_cnt+1
            LogUtil.debug("db_reader putQ:"+binascii.hexlify(send_buff).decode())       
        conn.commit()    
        
    LogUtil.info("db_reader:"+name+" end")
Example #9
    def update(self, labels, preds):
        check_label_shapes(labels, preds)
        if self.is_logging:
            log = LogUtil().getlogger()
            labelUtil = LabelUtil.getInstance()
        self.batch_loss = 0.
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            for i in range(int(int(self.batch_size) / int(self.num_gpu))):

                l = remove_blank(label[i])
                p = []
                for k in range(int(self.seq_length)):
                    p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                p = pred_best(p)

                l_distance = levenshtein_distance(l, p)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                if self.is_logging:
                    log.info("label: %s " % (labelUtil.convert_num_to_word(l)))
                    log.info("pred : %s , cer: %f (distance: %d/ label length: %d)" % (
                        labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
                self.num_inst += 1
                self.sum_metric += this_cer
                if self.is_epoch_end:
                    loss = ctc_loss(l, pred, i, int(self.seq_length), int(self.batch_size), int(self.num_gpu))
                    self.batch_loss += loss
                    if self.is_logging:
                        log.info("loss: %f " % loss)
        self.total_ctc_loss += self.batch_loss
Example #10
    def unpack_fixedlen(self, buff, separator=',', write_tag_name=True):
        message=Message()
        message.separator=separator
        message.has_tag_name=write_tag_name
        message.message_type=self.message_type
        
        (message_type, body_len)=self.header_struct.unpack_from(buff)
        if message_type!=self.message_type:
            #TODO:error format check
            LogUtil.error("Error message_type,expected="+str(self.message_type)+",act_type="+str(message_type))
            return
        #TODO:check buff size
        #TODO:check checksum        
        body_buff=buff[self.header_struct.size:]
        bytes_processed=0
        body_tuple=self.body_struct.unpack_from(body_buff)
        if len(body_tuple)!=self.field_cnt:
            return
        
        rtn_str=''
        for field_index in range(len(body_tuple)):
            if field_index>0:
                rtn_str=rtn_str+separator
            if write_tag_name:
                rtn_str=rtn_str+self.field_list[field_index][0]+'='
            
            if self.field_list[field_index][2]=='C':
                str_value=bytes.decode(body_tuple[field_index])
            elif self.field_list[field_index][2]=='N':
                str_value=str(body_tuple[field_index])
            else:    
                str_value=bytes.decode(body_tuple[field_index])
            bytes_processed=bytes_processed+self.field_list[field_index][3]
                
            rtn_str=rtn_str+str_value
            #TODO add a column to message_body_def to indicate whether the column is client Order ID
            if self.field_list[field_index][0]=='ClOrdID':
                message.client_order_id=bytes.decode(body_tuple[field_index])
            elif self.field_list[field_index][0]=='OrdStatus':
                message.order_status=bytes.decode(body_tuple[field_index])
            elif self.field_list[field_index][0]=='OrdRejReason':
                message.order_reject_reason=(body_tuple[field_index])
        message.message_str=rtn_str
        if bytes_processed!=body_len:
            LogUtil.error("bytes_process!=body_len,mesType="+str(message_type)+",body_len="+str(body_len)+",bytes_processed="+str(bytes_processed))

        return message
Example #11
    def update(self, labels, preds):
        check_label_shapes(labels, preds)

        log = LogUtil().getlogger()
        labelUtil = LabelUtil.getInstance()

        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            for i in range(int(int(self.batch_size) / int(self.num_gpu))):

                l = remove_blank(label[i])
                p = []
                for k in range(int(self.seq_length)):
                    p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                p = pred_best(p)

                l_distance = levenshtein_distance(l, p)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                log.info("label: %s " % (labelUtil.convert_num_to_word(l)))
                log.info("pred : %s , cer: %f (distance: %d/ label length: %d)" % (
                    labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
                self.num_inst += 1
                self.sum_metric += this_cer
            if self.is_epoch_end:
                loss = ctc_loss(l, pred, i, int(self.seq_length), int(self.batch_size), int(self.num_gpu))
                self.total_ctc_loss += loss
                log.info("loss: %f " % loss)
Example #12
def tgw_recv(name, status, sock, resp_queue, config_file):
    LogUtil.get_instance(config_file, "tgw_recv")
    LogUtil.info("tgw_recv:"+name+" begin")
    
    while status.value==0:
        try:
            recv_data = sock.recv(1024)
            if not recv_data:
                LogUtil.error('Recv message error!')
            else:
                LogUtil.debug('tgw recv:'+binascii.hexlify(recv_data).decode())
                #to make the recv faster, do NOT process more, just put the message to the queue
                resp_queue.put(recv_data)
        finally:
            pass
            #LogUtil.debug("")

    LogUtil.info("tgw_recv:"+name+" end")
Example #13
def mdgw_recv(name, status, sock, resp_queue, config_file):
    LogUtil.get_instance(config_file, "mdgw_recv")
    LogUtil.info("mdgw_recv:"+name+" begin")
    buf_id = 0
    while status.value == 0:
        try:
            recv_data = sock.recv(1024)
            if not recv_data:
                LogUtil.error('Recv message error!')
            else:
                buf_id = buf_id+1
                src_time = datetime.datetime.now()
                # to make the recv faster, do NOT process more, just put the message to the queue
                resp_queue.put((buf_id, src_time, recv_data))
                LogUtil.debug('mdgw recv:'+binascii.hexlify(recv_data).decode())
        finally:
            pass
            # LogUtil.debug("")
    LogUtil.info("mdgw_recv:"+name+" end")
Example #14
def test_unpack(buff_str):
    factory=msg.MessageProcessorFactory()
    factory.load_from_db('message_config.s3db', [])
    left_buff=binascii.a2b_hex(buff_str)
    if len(left_buff)<Message.header_len+Message.header_len:
        LogUtil.info("next message not ready1")
    (message_type, body_len)=msg.get_message_header(left_buff)
    next_message_len=body_len+ Message.header_len+Message.footer_len
    while next_message_len<=len(left_buff):
        message_processor=factory.build_message_processor(message_type)
        message=message_processor.unpack(left_buff)
        LogUtil.debug("message:"+message.toString())
        left_buff=left_buff[next_message_len:]
        if len(left_buff)<Message.header_len+Message.footer_len:
            LogUtil.debug("break @left size:"+str(len(left_buff)))
            break
        else:
            (message_type, body_len)=msg.get_message_header(left_buff)
            next_message_len=body_len+ Message.header_len+Message.footer_len
            LogUtil.debug("MsgType="+str(message_type)+" body_len="+str(body_len))
    LogUtil.debug("left buff:"+binascii.hexlify(left_buff).decode())
Example #15
 def decrypt(self, content, secret, crypto_type):
     '''
     Decrypt
     :param content: content to decrypt
     :param secret: secret key
     :param crypto_type: encryption type
     :return:
     '''
     if crypto_type == "AES":
         try:
             content = base64.b64decode(content)
             obj = AES.new(secret, AES.MODE_ECB)
             return obj.decrypt(content).decode('utf-8').strip()
         except:
             LogUtil().error("AES解密失败!----" + traceback.format_exc())
     elif crypto_type == "3DES":
         try:
             content = base64.b64decode(content)
             obj = DES3.new(secret, DES3.MODE_ECB)
             return obj.decrypt(content).decode('utf-8').strip()
         except:
             LogUtil().error("3DES解密失败!----" + traceback.format_exc())
Example #16
    def sample_normalize(self, k_samples=1000, overwrite=False):
        """ Estimate the mean and std of the features from the training set
        Params:
            k_samples (int): Use this number of samples for estimation
        """
        log = LogUtil().getlogger()
        log.info("Calculating mean and std from samples")
        # if k_samples is negative then it goes through total dataset
        if k_samples < 0:
            audio_paths = self.audio_paths

        # using sample
        else:
            k_samples = min(k_samples, len(self.train_audio_paths))
            samples = self.rng.sample(self.train_audio_paths, k_samples)
            audio_paths = samples
        manager = Manager()
        return_dict = manager.dict()
        jobs = []
        for threadIndex in range(cpu_count()):
            proc = Process(target=self.preprocess_sample_normalize, args=(threadIndex, audio_paths, overwrite, return_dict))
            jobs.append(proc)
            proc.start()
        for proc in jobs:
            proc.join()

        feat = np.sum(np.vstack([item['feat'] for item in return_dict.values()]), axis=0)
        count = sum([item['count'] for item in return_dict.values()])
        feat_squared = np.sum(np.vstack([item['feat_squared'] for item in return_dict.values()]), axis=0)

        self.feats_mean = feat / float(count)
        self.feats_std = np.sqrt(feat_squared / float(count) - np.square(self.feats_mean))
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_mean'), self.feats_mean)
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_std'), self.feats_std)
        log.info("End calculating mean and std from samples")
Example #17
    def sample_normalize(self, k_samples=1000, overwrite=False):
        """ Estimate the mean and std of the features from the training set
        Params:
            k_samples (int): Use this number of samples for estimation
        """
        log = LogUtil().getlogger()
        log.info("Calculating mean and std from samples")
        # if k_samples is negative then it goes through total dataset
        if k_samples < 0:
            audio_paths = self.audio_paths

        # using sample
        else:
            k_samples = min(k_samples, len(self.train_audio_paths))
            samples = self.rng.sample(self.train_audio_paths, k_samples)
            audio_paths = samples
        manager = Manager()
        return_dict = manager.dict()
        jobs = []
        num_processes = min(len(audio_paths), cpu_count())
        split_size = int(
            math.ceil(float(len(audio_paths)) / float(num_processes)))
        audio_paths_split = []
        for i in range(0, len(audio_paths), split_size):
            audio_paths_split.append(audio_paths[i:i + split_size])

        for thread_index in range(num_processes):
            proc = Process(target=self.preprocess_sample_normalize,
                           args=(thread_index, audio_paths_split[thread_index],
                                 overwrite, return_dict))
            jobs.append(proc)
            proc.start()
        for proc in jobs:
            proc.join()

        feat = np.sum(np.vstack(
            [item['feat'] for item in return_dict.values()]),
                      axis=0)
        count = sum([item['count'] for item in return_dict.values()])
        print(feat, count)
        feat_squared = np.sum(np.vstack(
            [item['feat_squared'] for item in return_dict.values()]),
                              axis=0)

        self.feats_mean = feat / float(count)
        self.feats_std = np.sqrt(feat_squared / float(count) -
                                 np.square(self.feats_mean))
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_mean'),
            self.feats_mean)
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_std'),
            self.feats_std)
        log.info("End calculating mean and std from samples")
Example #18
    def load_metadata_from_desc_file(self, desc_file, partition='train',
                                     max_duration=16.0,):
        """ Read metadata from the description file
            (possibly takes long, depending on the filesize)
        Params:
            desc_file (str):  Path to a JSON-line file that contains labels and
                paths to the audio files
            partition (str): One of 'train', 'validation' or 'test'
            max_duration (float): In seconds, the maximum duration of
                utterances to train or test on
        """
        logger = LogUtil().getlogger()
        logger.info('Reading description file: {} for partition: {}'
                    .format(desc_file, partition))
        audio_paths, durations, texts = [], [], []
        with open(desc_file) as json_line_file:
            for line_num, json_line in enumerate(json_line_file):
                try:
                    spec = json.loads(json_line)
                    if float(spec['duration']) > max_duration:
                        continue
                    audio_paths.append(spec['key'])
                    durations.append(float(spec['duration']))
                    texts.append(spec['text'])
                except Exception as e:
                    # Change to (KeyError, ValueError) or
                    # (KeyError,json.decoder.JSONDecodeError), depending on
                    # json module version
                    logger.warn('Error reading line #{}: {}'
                                .format(line_num, json_line))
                    logger.warn(str(e))

        if partition == 'train':
            self.count = len(audio_paths)
            self.train_audio_paths = audio_paths
            self.train_durations = durations
            self.train_texts = texts
        elif partition == 'validation':
            self.val_audio_paths = audio_paths
            self.val_durations = durations
            self.val_texts = texts
            self.val_count = len(audio_paths)
        elif partition == 'test':
            self.test_audio_paths = audio_paths
            self.test_durations = durations
            self.test_texts = texts
        else:
            raise Exception("Invalid partition to load metadata. "
                            "Must be train/validation/test")
Example #19
    def update(self, labels, preds):
        check_label_shapes(labels, preds)
        if self.is_logging:
            log = LogUtil().getlogger()
            labelUtil = LabelUtil()
        self.batch_loss = 0.
        # log.info(self.audio_paths)
        host_name = socket.gethostname()
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            seq_length = len(pred) / int(
                int(self.batch_size) / int(self.num_gpu))

            for i in range(int(int(self.batch_size) / int(self.num_gpu))):
                l = remove_blank(label[i])
                p = []
                probs = []
                for k in range(int(seq_length)):
                    p.append(
                        np.argmax(pred[
                            k * int(int(self.batch_size) / int(self.num_gpu)) +
                            i]))
                    probs.append(
                        pred[k * int(int(self.batch_size) / int(self.num_gpu))
                             + i])
                p = pred_best(p)

                l_distance = levenshtein_distance(l, p)
                # l_distance = editdistance.eval(labelUtil.convert_num_to_word(l).split(" "), res)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                if self.is_logging and this_cer > 0.4:
                    log.info("%s label: %s " %
                             (host_name, labelUtil.convert_num_to_word(l)))
                    log.info(
                        "%s pred : %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, labelUtil.convert_num_to_word(p),
                           this_cer, l_distance, len(l)))
                    # log.info("ctc_loss: %.2f" % ctc_loss(l, pred, i, int(seq_length), int(self.batch_size), int(self.num_gpu)))
                self.num_inst += 1
                self.sum_metric += this_cer
                # if self.is_epoch_end:
                #    loss = ctc_loss(l, pred, i, int(seq_length), int(self.batch_size), int(self.num_gpu))
                #    self.batch_loss += loss
                #    if self.is_logging:
                #        log.info("loss: %f " % loss)
        self.total_ctc_loss += 0  # self.batch_loss
Example #20
 def cal_signature(self, params):
     '''
     Generate an MD5-hashed signature
     :param params: dict
     :return:
     '''
     try:
         if params is None:
             params = {}
         if (isinstance(params, dict)):
             params = json.dumps(params,
                                 ensure_ascii=False,
                                 separators=(',', ':'))
             content = str(params) + self.__APP_ID + self.__APP_SECURITY
             m = hashlib.md5()
             if (isinstance(content, str)):
                 m.update(content.encode('utf-8'))
                 return (m.hexdigest()).upper()
     except:
         LogUtil().error("生成签名失败!----" + traceback.format_exc())
Example #21
 def cal_secret(self, token, crypto_type):
     '''
     Generate the secret key
     :param token: token credential
     :param crypto_type: encryption type
     :return:
     '''
     try:
         if token is None or token == "":
             token = self.__DEFAULT_TOKEN
         source_content = (token + self.__AES_BASE_KEY).encode("utf-8")
         sha_result = hashlib.sha256(source_content).hexdigest()
         if "AES" == crypto_type:
             key_length = 16
         elif "3DES" == crypto_type:
             key_length = 24
         else:
             return None
         return sha_result[0:key_length]
     except:
         LogUtil().error("生成密钥失败!----" + traceback.format_exc())
Example #22
def load_data(args, wav_file):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'mode must be the one of the followings - train,predict,load')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train',
                                                  'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean(
        'train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    max_freq = args.config.getint('data', 'max_freq')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil()

    # test_json = "resources/d.json"
    datagen = DataGenerator(save_dir=save_dir,
                            model_name=model_name,
                            max_freq=max_freq)
    datagen.train_audio_paths = [wav_file]
    datagen.train_durations = [get_duration_wave(wav_file)]
    datagen.train_texts = ["1 1"]
    datagen.count = 1
    # datagen.load_train_data(test_json, max_duration=max_duration)
    labelutil = load_labelutil(labelUtil, is_bi_graphemes, language="zh")
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
    datagen.get_meta_from_file(
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train'
                                             or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    max_t_count = datagen.get_max_seq_length(partition="test")
    max_label_length = \
        datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean(
        'train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(
            partition="train",
            count=datagen.count,
            datagen=datagen,
            batch_size=batch_size,
            num_label=max_label_length,
            init_states=init_states,
            seq_length=max_t_count,
            width=whcs.width,
            height=whcs.height,
            sort_by_duration=sort_by_duration,
            is_bi_graphemes=is_bi_graphemes,
            buckets=buckets,
            save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    return data_loaded, args
Example #23
class LabelUtil(Singleton):
    _log = None

    # dataPath
    def __init__(self):
        self._log = LogUtil().getlogger()
        self._log.debug("LabelUtil init")

    def load_unicode_set(self, unicodeFilePath):
        self.byChar = {}
        self.byIndex = {}
        self.byList = []
        self.unicodeFilePath = unicodeFilePath

        with open(unicodeFilePath, 'rt', encoding='UTF-8') as data_file:

            self.count = 0
            for i, r in enumerate(data_file):
                ch, inx = r.rsplit(",", 1)
                self.byChar[ch] = int(inx)
                self.byIndex[int(inx)] = ch
                self.byList.append(ch)
                self.count += 1

    def to_unicode(self, src, index):
        # 1 byte
        code1 = int(ord(src[index + 0]))

        index += 1

        result = code1

        return result, index

    def convert_word_to_grapheme(self, label):

        result = []

        index = 0
        while index < len(label):
            (code, nextIndex) = self.to_unicode(label, index)

            result.append(label[index])

            index = nextIndex

        return result, "".join(result)

    def convert_word_to_num(self, word):
        try:
            label_list, _ = self.convert_word_to_grapheme(word)

            label_num = []

            for char in label_list:
                # skip word
                if char == "":
                    pass
                else:
                    label_num.append(int(self.byChar[strQ2B(char)]))

            # tuple typecast: read only, faster
            return tuple(label_num)

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

        except KeyError as err:
            self._log.error("unicodeSet Key not found: %s" % err)
            exit(-1)

    def convert_bi_graphemes_to_num(self, word):
        label_num = []

        for char in word:
            # skip word
            if char == "":
                pass
            else:
                label_num.append(int(self.byChar[strQ2B(
                    char.decode("utf-8"))]))

        # tuple typecast: read only, faster
        return tuple(label_num)

    def convert_num_to_word(self, num_list):
        try:
            label_list = []
            for num in num_list:
                label_list.append(self.byIndex[num])

            return ' '.join(label_list)

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

        except KeyError as err:
            self._log.error("unicodeSet Key not found: %s" % err)
            exit(-1)

    def get_count(self):
        try:
            return self.count

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

    def get_unicode_file_path(self):
        try:
            return self.unicodeFilePath

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

    def get_blank_index(self):
        return self.byChar["-"]

    def get_space_index(self):
        return self.byChar["$"]
Example #24
import time
import sys
import traceback
import threading
import requests

from ini_op import Config
from log_util import LogUtil
from printer import Printer
from PyQt5.QtNetwork import *
from globals import *
# ini file
from PyQt5.QtPrintSupport import QPrinter

version = "3.0"
log = LogUtil()
config = Config("config.ini")
port = 9999
# port
# host = 'localhost';#OM server
host = config.get("baseconf", "oms_host")
BUFSIZE = 8192
url = config.get("baseconf", "url")
bk_url = config.get("baseconf", "url")
# encrypt/decrypt url
token = config.get("baseconf", "token")
token_len = config.get("baseconf", "token_len")
token_oms = config.get("baseconf", "token_oms")
token_oms_len = config.get("baseconf", "token_oms_len")
p_token = config.get("baseconf", "p_token")
p_token_len = config.get("baseconf", "p_token_len")
Example #25
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            seq_length=seq_len)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate,
                                     momentum=momentum,
                                     optimizer=optimizer)

    n_epoch = begin_epoch
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')

    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        module.init_optimizer(kvstore='device',
                              optimizer=args.config.get('train', 'optimizer'),
                              optimizer_params={
                                  'clip_gradient': clip_gradient,
                                  'wd': weight_decay
                              },
                              force_init=True)

    reset_optimizer()

    while True:

        if n_epoch >= num_epoch:
            break

        eval_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):

            if data_batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()
            if (nbatch + 1) % show_every == 0:
                module.update_metric(eval_metric, data_batch.label)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        for nbatch, data_batch in enumerate(data_val):
            module.update_metric(eval_metric, data_batch.label)
        #module.score(eval_data=data_val, num_batch=None, eval_metric=eval_metric, reset=True)

        data_train.reset()
        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
Example #26
        res = ctc_greedy_decode(probs, self.labelUtil.byList)
        log.info("greedy decode cost %.3f, result is:\n%s" %
                 (time.time() - st, res))
        beam_size = 5
        from stt_metric import ctc_beam_decode
        st = time.time()
        results = ctc_beam_decode(scorer=self.scorer,
                                  beam_size=beam_size,
                                  vocab=self.labelUtil.byList,
                                  probs=probs)
        log.info("beam decode cost %.3f, result is:\n%s" %
                 (time.time() - st, "\n".join(results)))
        return "greedy:\n" + res + "\nbeam:\n" + "\n".join(results)


if __name__ == '__main__':
    if len(sys.argv) <= 1:
        raise Exception('cfg file path must be provided. ' +
                        'ex)python main.py --configfile examplecfg.cfg')
    args = parse_args(sys.argv[1])
    # set log file name
    log_filename = args.config.get('common', 'log_filename')
    log = LogUtil(filename=log_filename).getlogger()
    otherNet = Net(args)

    server = HTTPServer(('', args.config.getint('common', 'port')),
                        SimpleHTTPRequestHandler)
    log.info('Started httpserver on port')
    # Wait forever for incoming http requests
    server.serve_forever()
Example #27
class LabelUtil:
    _log = None

    # dataPath
    def __init__(self):
        self._log = LogUtil().getlogger()
        self._log.debug("LabelUtil init")

    def load_unicode_set(self, unicodeFilePath):
        self.byChar = {}
        self.byIndex = {}
        self.unicodeFilePath = unicodeFilePath

        with open(unicodeFilePath) as data_file:
            data_file = csv.reader(data_file, delimiter=',')

            self.count = 0
            for r in data_file:
                self.byChar[r[0]] = int(r[1])
                self.byIndex[int(r[1])] = r[0]
                self.count += 1

    def to_unicode(self, src, index):
        # 1 byte
        code1 = int(ord(src[index + 0]))

        index += 1

        result = code1

        return result, index

    def convert_word_to_grapheme(self, label):

        result = []

        index = 0
        while index < len(label):
            (code, nextIndex) = self.to_unicode(label, index)

            result.append(label[index])

            index = nextIndex

        return result, "".join(result)

    def convert_word_to_num(self, word):
        try:
            label_list, _ = self.convert_word_to_grapheme(word)

            label_num = []

            for char in label_list:
                # skip word
                if char == "":
                    pass
                else:
                    label_num.append(int(self.byChar[char]))

            # tuple typecast: read only, faster
            return tuple(label_num)

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

        except KeyError as err:
            self._log.error("unicodeSet Key not found: %s" % err)
            exit(-1)

    def convert_bi_graphemes_to_num(self, word):
        label_num = []

        for char in word:
            # skip word
            if char == "":
                pass
            else:
                label_num.append(int(self.byChar[char]))

        # tuple typecast: read only, faster
        return tuple(label_num)

    def convert_num_to_word(self, num_list):
        try:
            label_list = []
            for num in num_list:
                label_list.append(self.byIndex[num])

            return ''.join(label_list)

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

        except KeyError as err:
            self._log.error("unicodeSet Key not found: %s" % err)
            exit(-1)

    def get_count(self):
        try:
            return self.count

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

    def get_unicode_file_path(self):
        try:
            return self.unicodeFilePath

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

    def get_blank_index(self):
        return self.byChar["-"]

    def get_space_index(self):
        return self.byChar["$"]
Example #28
 def cal_token(self, username, password, login_type="NORMAL", env=ENV):
     '''
     Generate a token
     :param username: user name
     :param password: password
     :param login_type: owner client: NORMAL, supplier: SUPPLIER, operations back office: MANAGER, property-management back office: PMMANAGER
     :param env: environment, e.g. beta or Fix
     :return:
     '''
     # Owner client: NORMAL
     if login_type == self.LOGIN_TYPE_NORMAL:
         params = {"loginName": username, "password": password}
         url = self.cal_url("sso", "sso/login", env)
         try:
             result = self.http_post(url, params, None)
             if result is not None:
                 return result["result"]
         except:
             LogUtil().error("生成NORMAL token出错!----" +
                             traceback.format_exc())
     # Supplier: SUPPLIER
     elif login_type == self.LOGIN_TYPE_SUPPLIER:
         params = {"username": username, "password": password}
         url = self.cal_url("supplier", "supplierSso/login", env)
         try:
             result = self.http_post(url,
                                     params,
                                     None,
                                     login_type=login_type)
             if result is not None:
                 token_result = json.loads(result["result"])
                 return token_result["token"]
         except:
             LogUtil().error("生成SUPPLIER token出错!----" +
                             traceback.format_exc())
     # Operations back office: MANAGER
     elif login_type == self.LOGIN_TYPE_MANAGER:
         params = {
             "username": username,
             "password": password,
             "systemType": 2
         }
         url = self.cal_url("sys-sso", "systemSso/loginSys", env)
         try:
             result = self.http_post(url,
                                     params,
                                     None,
                                     login_type=login_type)
             if result is not None:
                 return result["result"]
         except:
             LogUtil().error("生成MANAGER token出错!----" +
                             traceback.format_exc())
     # Property-management back office: PMMANAGER
     elif login_type == self.LOGIN_TYPE_PMMANAGER:
         params = {
             "username": str(username),
             "password": str(password),
             "systemType": 1
         }
         url = self.cal_url("sys-sso", "systemSso/loginSys", env)
         try:
             result = self.http_post(url,
                                     params,
                                     None,
                                     login_type=login_type)
             if result is not None:
                 return result["result"]
         except:
             LogUtil().error("生成PMMANAGER token出错!----" +
                             traceback.format_exc())
     # Staff client: STAFF
     elif login_type == self.LOGIN_TYPE_STAFF:
         params = {"username": str(username), "password": str(password)}
         url = self.cal_url("sys-sso", "systemSso/login", env)
         try:
             result = self.http_post(url,
                                     params,
                                     None,
                                     login_type=self.LOGIN_TYPE_NORMAL)
             if result is not None:
                 return result["result"]
         except:
             LogUtil().error("生成STAFF token出错!----" +
                             traceback.format_exc())
Example #29
 def __init__(self):
     self._log = LogUtil().getlogger()
     self._log.debug("LabelUtil init")
Example #30
import sys
import os
from stt_io_iter import STTIter
from label_util import LabelUtil
from log_util import LogUtil
import numpy as np
from stt_datagenerator import DataGenerator
from stt_metric import STTMetric
from stt_bi_graphemes_util import generate_bi_graphemes_dictionary
from stt_bucketing_module import STTBucketingModule
from stt_io_bucketingiter import BucketSTTIter
sys.path.insert(0, "../../python")

# os.environ['MXNET_ENGINE_TYPE'] = "NaiveEngine"
os.environ['MXNET_ENGINE_TYPE'] = "ThreadedEnginePerDevice"
os.environ['MXNET_ENABLE_GPU_P2P'] = "0"

logUtil = LogUtil.getInstance()


class WHCS:
    width = 0
    height = 0
    channel = 0
    stride = 0


class ConfigLogger(object):
    def __init__(self, log):
        self.__log = log

    def __call__(self, config):
        self.__log.info("Config:")
Example #31
def db_writer(name, status, resp_queue, config_file):
    LogUtil.get_instance(config_file, "db_writer")
    LogUtil.info("db_writer:"+name+" begin")
    
    config=configparser.ConfigParser()
    config.read(config_file)
    
    factory=msg.MessageProcessorFactory()
    db_file=config.get("message_config", "db_file")
    factory.load_from_db(db_file, [])
    
    read_timeout=config.getint("db_writer", "resp_queue_timeout")
    LogUtil.info("read_timeout:"+str(read_timeout))
    if read_timeout<=0 or read_timeout>60:
        read_timeout=60
        LogUtil.info("read_timeout changed to:"+str(read_timeout))
    
    host=config.get("reqresp", "host")
    database=config.get("reqresp", "database")
    user=config.get("reqresp","user")
    password=config.get("reqresp", "password")
    
    conn=pgdb.connect(database=database, host=host, user=user, password=password)
    update_curs=conn.cursor()
    
    update_resp="""UPDATE req_resp 
    SET rept_status=%(rept_status)s,ex_order_status=%(ex_order_status)s, 
    err_code=%(err_code)s, resp_text=%(resp_text)s, resp_time=localtimestamp 
    WHERE oms_order_id=%(oms_order_id)s    
    """ 
    update_dict={'rept_status':'', 'ex_order_status':'', 'err_code':'', 'resp_text':'', 'oms_order_id':''}
    left_buff=b''
    while status.value==0:
        #TODO:refactor,try recv; and then process the buff
        #TODO:when processing buff, abstract the condition to next_message_ready()
        try:
            recv_buff=resp_queue.get(block=True, timeout=read_timeout)
            left_buff=left_buff+recv_buff
            if len(left_buff)<Message.header_len+Message.header_len:
                continue
            (message_type, body_len)=msg.get_message_header(left_buff)
            next_message_len=body_len+ Message.header_len+Message.footer_len
            while next_message_len<=len(left_buff):
                try:
                    message_processor=factory.build_message_processor(message_type)
                    message=message_processor.unpack(left_buff)
                    LogUtil.debug("message:"+message.toString())
                    if True:#TODO placeholder, check if was order execution report
                        update_dict['rept_status']='4'
                        update_dict['ex_order_status']=message.order_status
                        update_dict['err_code']=message.order_reject_reason
                        update_dict['resp_text']=message.message_str
                        update_dict['oms_order_id']=message.client_order_id
                        
                        update_curs.execute(update_resp, update_dict)
                        if update_curs.rowcount!=1:
                            #TODO error handle, rollback?
                            LogUtil.error("no data update"+message.toString())
                        else:
                            conn.commit()                    
                except KeyError:
                    LogUtil.error("unkown message type:"+str(message_type))
                left_buff=left_buff[next_message_len:]
                if len(left_buff)<Message.header_len+Message.footer_len:
                    break
                else:
                    (message_type, body_len)=msg.get_message_header(left_buff)
                    next_message_len=body_len+ Message.header_len+Message.footer_len
        except queue.Empty:
            LogUtil.debug("resp_queue no data")
#        except KeyError:
#            LogUtil.error("unkown message type:"+str(message_type))
        else:    
            LogUtil.info("db_writer finished processing:"+message.toString())
    LogUtil.info("db_writer:"+name+" end")
Example #32
    def update(self, labels, preds):
        check_label_shapes(labels, preds)
        if self.is_logging:
            log = LogUtil().getlogger()
            labelUtil = LabelUtil()
        self.batch_loss = 0.
        shouldPrint = True
        host_name = socket.gethostname()
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()
            seq_length = len(pred) / int(
                int(self.batch_size) / int(self.num_gpu))
            # sess = tf.Session()
            for i in range(int(int(self.batch_size) / int(self.num_gpu))):
                l = remove_blank(label[i])
                # p = []
                probs = []
                for k in range(int(seq_length)):
                    # p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                    probs.append(
                        pred[k * int(int(self.batch_size) / int(self.num_gpu))
                             + i])
                # p = pred_best(p)
                probs = np.array(probs)
                st = time.time()
                beam_size = 20
                results = ctc_beam_decode(self.scorer, beam_size,
                                          labelUtil.byList, probs)
                log.info("decode by ctc_beam cost %.2f result: %s" %
                         (time.time() - st, "\n".join(results)))

                res_str1 = ctc_greedy_decode(probs, labelUtil.byList)
                log.info("decode by pred_best: %s" % res_str1)

                # max_time_steps = int(seq_length)
                # input_log_prob_matrix_0 = np.log(probs)  # + 2.0
                #
                # # len max_time_steps array of batch_size x depth matrices
                # inputs = ([
                #   input_log_prob_matrix_0[t, :][np.newaxis, :] for t in range(max_time_steps)]
                # )
                #
                # inputs_t = [ops.convert_to_tensor(x) for x in inputs]
                # inputs_t = array_ops.stack(inputs_t)
                #
                # st = time.time()
                # # run CTC beam search decoder in tensorflow
                # decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(inputs_t,
                #                                                            [max_time_steps],
                #                                                            beam_width=10,
                #                                                            top_paths=3,
                #                                                            merge_repeated=False)
                # tf_decoded, tf_log_probs = sess.run([decoded, log_probabilities])
                # st1 = time.time() - st
                # for index in range(3):
                #   tf_result = ''.join([labelUtil.byIndex.get(i + 1, ' ') for i in tf_decoded[index].values])
                #   print("%.2f elpse %.2f, %s" % (tf_log_probs[0][index], st1, tf_result))
                l_distance = editdistance.eval(
                    labelUtil.convert_num_to_word(l).split(" "), res_str1)
                # l_distance_beam = editdistance.eval(labelUtil.convert_num_to_word(l).split(" "), beam_result[0][1])
                l_distance_beam_cpp = editdistance.eval(
                    labelUtil.convert_num_to_word(l).split(" "), results[0])
                self.total_n_label += len(l)
                # self.total_l_dist_beam += l_distance_beam
                self.total_l_dist_beam_cpp += l_distance_beam_cpp
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                if self.is_logging:
                    # log.info("%s label: %s " % (host_name, labelUtil.convert_num_to_word(l)))
                    # log.info("%s pred : %s , cer: %f (distance: %d/ label length: %d)" % (
                    #     host_name, labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
                    log.info("%s label: %s " %
                             (host_name, labelUtil.convert_num_to_word(l)))
                    log.info(
                        "%s pred : %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, res_str1, this_cer, l_distance, len(l)))
                    # log.info("%s predb: %s , cer: %f (distance: %d/ label length: %d)" % (
                    #     host_name, " ".join(beam_result[0][1]), float(l_distance_beam) / len(l), l_distance_beam,
                    #     len(l)))
                    log.info(
                        "%s predc: %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, " ".join(
                            results[0]), float(l_distance_beam_cpp) / len(l),
                           l_distance_beam_cpp, len(l)))
                self.total_ctc_loss += self.batch_loss
                self.placeholder = res_str1 + "\n" + "\n".join(results)
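
    # The listing stops before STTMetric.get_name_value(); the sketch below is an
    # assumption, inferred only from how do_training() unpacks its return value
    # (cer, n_label, l_dist, ctc_loss) and from the counters accumulated in update().
    def get_name_value(self):
        cer = float(self.total_l_dist) / float(max(self.total_n_label, 1))
        return cer, self.total_n_label, self.total_l_dist, self.total_ctc_loss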
Example #33
0
def tgw_send(name, status,  sock, req_queue, config_file):
    LogUtil.get_instance(config_file, "tgw_send")
    LogUtil.info("tgw_send:"+name+" begin")
    
    config=configparser.ConfigParser()
    config.read(config_file)
    
    heartbeat_interval=config.getint("tgw", "heartbeat_interval")
    LogUtil.info("heartbeat_interval:"+str(heartbeat_interval))
    if heartbeat_interval is None or heartbeat_interval<=0 or heartbeat_interval>=1800:
        heartbeat_interval=60
        LogUtil.info("heartbeat_interval changed to:"+str(heartbeat_interval))
    
    send_heartbeat_interval=config.getint("tgw_send", "send_heartbeat_interval")
    LogUtil.info("send_heartbeat_interval:"+str(send_heartbeat_interval))
    if send_heartbeat_interval<=0 or send_heartbeat_interval>=heartbeat_interval:
        send_heartbeat_interval=heartbeat_interval
        LogUtil.info("send_heartbeat_interval changed to:"+str(send_heartbeat_interval))
    
    read_timeout=config.getint("tgw_send", "req_queue_timeout")
    LogUtil.info("read_timeout:"+str(read_timeout))
    if read_timeout<=0 or read_timeout>send_heartbeat_interval/2:
        read_timeout=send_heartbeat_interval/2
        LogUtil.info("read_timeout changed to:"+str(read_timeout))
    
    last_send_time=0
    heartbeat=msg.packHeartbeatMessage()
    
    while status.value==0:
        try:
            message=req_queue.get(block=True, timeout=read_timeout)
        except queue.Empty:
            LogUtil.debug("req_queue no data")
            this_time=time.time()
            if this_time-last_send_time>=send_heartbeat_interval:
                req_queue.put(heartbeat)
                last_send_time=this_time
        else:
            sock.sendall(message)
            LogUtil.info("tgw_send:"+binascii.hexlify(message).decode())
    LogUtil.info("tgw_send:"+name+" end")
Example #34
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil.getInstance().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_validation_metric,is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_train_metric,is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch=begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))


    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        optimizer_params = {'lr_scheduler': lr_scheduler,
                            'clip_gradient': clip_gradient,
                            'wd': weight_decay}
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)
    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch+1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args)+"n_epoch"+str(n_epoch)+"n_batch", epoch=(int((nbatch+1)/save_checkpoint_every_n_batch)-1), save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # when is_train = False it leads to high cer when batch_norm
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        lr_scheduler.learning_rate=learning_rate/learning_rate_annealing

    log.info('FINISH')
Example #35
0
import logging
from datetime import timedelta

from celery import Celery
from celery.signals import setup_logging

from log_util import LogUtil

app = Celery('task_demo1', broker='redis://127.0.0.1:6379/0')

# add logging config
conf = {
    'filename': 'log/demo1.log',
    'level': 'DEBUG',
}
LogUtil(**conf)

# keep Celery from replacing the logging configuration set up by LogUtil above;
# the connected handler simply returns the already-configured root logger
fn = lambda **kwargs: logging.getLogger()
setup_logging.connect(fn)

CELERYBEAT_SCHEDULE = {
    'add-every-10-second': {
        'task': 'demo1.add',
        'schedule': timedelta(seconds=10),
        'args': (3, 2)
    },
}
CELERY_TIMEZONE = 'Asia/Shanghai'


@app.task
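# The task body is cut off in this listing; a minimal add() matching the
# 'demo1.add' entry in CELERYBEAT_SCHEDULE above might look like this (an assumption):
def add(x, y):
    logging.getLogger(__name__).info('add(%s, %s) called', x, y)
    return x + y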
Example #36
0
def load_data(args):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception('mode must be the one of the followings - train,predict,load')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean('train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()
    if mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json, max_duration=max_duration)
        datagen.load_validation_data(val_json, max_duration=max_duration)
        if is_bi_graphemes:
            if not os.path.isfile("resources/unicodemap_en_baidu_bi_graphemes.csv") or overwrite_bi_graphemes_dictionary:
                load_labelutil(labelUtil=labelUtil, is_bi_graphemes=False, language=language)
                generate_bi_graphemes_dictionary(datagen.train_texts+datagen.val_texts)
        load_labelutil(labelUtil=labelUtil, is_bi_graphemes=is_bi_graphemes, language=language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                log.info("Generate mean and std from samples")
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                log.info("Read mean and std from meta files")
                datagen.get_meta_from_file(
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    elif mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json, max_duration=max_duration)
        labelutil = load_labelutil(labelUtil, is_bi_graphemes, language="en")
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by its duration in ascending order to implement sortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = \
            datagen.get_max_label_length(partition="train", is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = \
            datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean('train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(partition="train",
                                    count=datagen.count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=sort_by_duration,
                                    is_bi_graphemes=is_bi_graphemes,
                                    buckets=buckets,
                                    save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    if mode == 'train' or mode == 'load':
        if is_bucketing:
            validation_loaded = BucketSTTIter(partition="validation",
                                              count=datagen.val_count,
                                              datagen=datagen,
                                              batch_size=batch_size,
                                              num_label=max_label_length,
                                              init_states=init_states,
                                              seq_length=max_t_count,
                                              width=whcs.width,
                                              height=whcs.height,
                                              sort_by_duration=False,
                                              is_bi_graphemes=is_bi_graphemes,
                                              buckets=buckets,
                                              save_feature_as_csvfile=save_feature_as_csvfile)
        else:
            validation_loaded = STTIter(partition="validation",
                                        count=datagen.val_count,
                                        datagen=datagen,
                                        batch_size=batch_size,
                                        num_label=max_label_length,
                                        init_states=init_states,
                                        seq_length=max_t_count,
                                        width=whcs.width,
                                        height=whcs.height,
                                        sort_by_duration=False,
                                        is_bi_graphemes=is_bi_graphemes,
                                        save_feature_as_csvfile=save_feature_as_csvfile)
        return data_loaded, validation_loaded, args
    elif mode == 'predict':
        return data_loaded, args
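
A usage sketch implied by the two return statements above (variable names are placeholders, not from the original code):

data_train, data_val, args = load_data(args)     # when mode is 'train' or 'load'
# data_test, args = load_data(args)              # when mode is 'predict'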
Example #37
0
def main():
    if len(sys.argv)<2:
        print("Usage: tgw.py config_file")
        sys.exit(0)

    #read mdgw connection config
    config_file=sys.argv[1]
        
    run_status=multiprocessing.Value('i', 0)  # 0: running; 1: exit

    message_header_struct = struct.Struct('!II')
    logon_struct = struct.Struct('!20s20sI16s32s')
    message_footer_struct = struct.Struct('!I')
    
    send_buff = ctypes.create_string_buffer(message_header_struct.size+logon_struct.size+message_footer_struct.size)
    
    bodyLength = logon_struct.size
    message_header =(1, bodyLength)
    message_header_struct.pack_into(send_buff, 0, *message_header)    
   
    
    LogUtil.get_instance(config_file, "log")
    LogUtil.info("Begin")
    
    config=configparser.ConfigParser()
    config.read(config_file)
    sender_comp=config.get("tgw","sender_comp")
    target_comp=config.get("tgw","target_comp")    
    password=config.get("tgw","password")
    app_ver_id=config.get("tgw","app_ver_id")    
    
    sender_comp = str.encode(sender_comp.ljust(20))
    target_comp = str.encode(target_comp.ljust(20))
    password = str.encode(password.ljust(16))
    app_ver_id = str.encode(app_ver_id.ljust(32))    
    
    logon_body = (sender_comp, target_comp, 30, password, app_ver_id)
    logon_struct.pack_into(send_buff, message_header_struct.size, *logon_body)
    check_sum = msg.calculate_check_sum(send_buff, message_header_struct.size+logon_struct.size)
    message_footer_struct.pack_into(send_buff, message_header_struct.size+logon_struct.size, check_sum)
    
    sock = socket.socket(socket.AF_INET,  socket.SOCK_STREAM)
    server_ip=config.get("tgw", "ip")
    server_port=config.getint("tgw", "port")

    #logger initialize
    
    server_address = (server_ip, server_port)
    sock.connect(server_address)
    sock.settimeout(5)
    sock.setblocking(True)  # note: equivalent to settimeout(None), so it overrides the 5s timeout above
    try:
        LogUtil.debug(binascii.hexlify(send_buff))
        sock.sendall(send_buff)
        recv_data = sock.recv(1024)
        if not recv_data:
            LogUtil.error('Recv error')
        else:
            LogUtil.info('Recv OK')
            LogUtil.info(binascii.hexlify(recv_data))
            unpack_recv_data = message_header_struct.unpack_from(recv_data)
            LogUtil.info(unpack_recv_data)
            #print(binascii.hexlify(recv_data))
            if unpack_recv_data[0]==1:
                LogUtil.info('Receive Login Confirm!')
                
                #TODO:send report sync
                factory=msg.MessageProcessorFactory()
                db_file=config.get("message_config", "db_file")
                factory.load_from_db(db_file, [])
                message_processor=factory.build_message_processor(5)
                buff=message_processor.pack("ReportIndex=1")
                sock.sendall(buff)
                
                req_queue=multiprocessing.Queue()
                resp_queue=multiprocessing.Queue()
                
                dbreader_proc=multiprocessing.Process(target=db_reader, args=('DBReader', run_status, req_queue, config_file))
                dbwriter_proc=multiprocessing.Process(target=db_writer, args=('DBWriter', run_status, resp_queue, config_file))
                send_proc=multiprocessing.Process(target=tgw_send, args=('TGW sender', run_status, sock, req_queue, config_file))
                recv_proc=multiprocessing.Process(target=tgw_recv, args=('TGW receiver', run_status, sock, resp_queue, config_file))
                
                dbreader_proc.start()
                dbwriter_proc.start()
                send_proc.start()
                recv_proc.start()
                
                time.sleep(10)
                
                cmd=input("enter command:")
                while cmd!='q':
                    time.sleep(2)
                    cmd=input("enter command:")
                    
                LogUtil.warning("sending exit cmd")
                run_status.value=1    
                
                dbreader_proc.join()
                dbwriter_proc.join()
                send_proc.join()
                recv_proc.join()
             
                # send the logout message and handle the response
                logout_message=msg.packLogoutMessage()
                sock.sendall(logout_message)
                recv_data = sock.recv(1024)
                if not recv_data:
                    LogUtil.error('Recv logout_message error!')
                else:
                    LogUtil.info('Recv logout_message OK')
                    LogUtil.debug(binascii.hexlify(recv_data))
            
    finally:
        sock.close()
        LogUtil.info('End')
Example #38
0
    args = parse_args(sys.argv[1])
    # set parameters from cfg file
    # give random seed
    random_seed = args.config.getint('common', 'random_seed')
    mx_random_seed = args.config.getint('common', 'mx_random_seed')
    # random seed for shuffling data list
    if random_seed != -1:
        np.random.seed(random_seed)
    # set mx.random.seed to give seed for parameter initialization
    if mx_random_seed != -1:
        mx.random.seed(mx_random_seed)
    else:
        mx.random.seed(hash(datetime.now()))
    # set log file name
    log_filename = args.config.get('common', 'log_filename')
    log = LogUtil(filename=log_filename).getlogger()

    # set parameters from data section(common)
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'Define mode in the cfg file first. ' +
            'train or predict or load can be the candidate for the mode.')

    # get meta file where character to number conversions are defined

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    batch_size = args.config.getint('common', 'batch_size')
    # check the number of gpus is positive divisor of the batch size for data parallel
    if batch_size % num_gpu != 0:
Example #39
0
    def __init__(self):
        self._log = LogUtil().getlogger()
        self._log.debug("LabelUtil init")
Example #40
0
    def __init__(self):
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided. ' +
                            'ex)python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        # set parameters from cfg file
        # give random seed
        self.random_seed = self.args.config.getint('common', 'random_seed')
        self.mx_random_seed = self.args.config.getint('common',
                                                      'mx_random_seed')
        # random seed for shuffling data list
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))
        # set log file name
        self.log_filename = self.args.config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check the number of gpus is positive divisor of the batch size for data parallel
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(self.args.config)

        default_bucket_key = 1600
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        self.model = STTBucketingModule(sym_gen=self.model_loaded,
                                        default_bucket_key=default_bucket_key,
                                        context=self.contexts)

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = self.args.config.getint('data', 'width')
        height = self.args.config.getint('data', 'height')
        self.model.bind(data_shapes=[
            ('data', (self.batch_size, default_bucket_key, width * height))
        ] + init_states,
                        label_shapes=[
                            ('label',
                             (self.batch_size,
                              self.args.config.getint('arch',
                                                      'max_label_length')))
                        ],
                        for_training=True)

        _, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        self.model.set_params(self.arg_params,
                              self.aux_params,
                              allow_extra=True,
                              allow_missing=True)

        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            self.log.info("vocab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.scorer = km.score
Example #41
0
def md_responsor(name, status, resp_queue, config_file):
    LogUtil.get_instance(config_file, "md_responsor")
    LogUtil.info("md_responsor:"+name+" begin")

    config = configparser.ConfigParser()
    config.read(config_file)

    factory = msg.MessageProcessorFactory()
    db_file = config.get("message_config", "db_file")
    factory.load_from_db(db_file, [])

    read_timeout = config.getint("md_responsor", "resp_queue_timeout")
    LogUtil.info("read_timeout:"+str(read_timeout))
    if read_timeout <= 0 or read_timeout > 60:
        read_timeout = 60
        LogUtil.info("read_timeout changed to:"+str(read_timeout))

    pub_buf = config.get("md_responsor", "pub_buf")
    pub_msg = config.get("md_responsor", "pub_msg")
    pub_buf_addr = config.get("md_responsor", "pub_buf_addr")
    pub_msg_addr = config.get("md_responsor", "pub_msg_addr")

    LogUtil.debug("pub_buf:"+pub_buf+",pub_buf_addr:"+pub_buf_addr)
    LogUtil.debug("pub_msg:"+pub_msg+",pub_msg_addr:"+pub_msg_addr)

    if pub_buf:
        buf_ctx = zmq.Context()
        buf_sock = buf_ctx.socket(zmq.PUB)
        buf_sock.bind(pub_buf_addr)
    if pub_msg:
        msg_ctx = zmq.Context()
        msg_sock = msg_ctx.socket(zmq.PUB)
        msg_sock.bind(pub_msg_addr)

    left_buff = b''
    message_id = 0
    while status.value == 0:
        # TODO:refactor,try recv; and then process the buff
        # TODO:when processing buff, abstract the condition to next_message_ready()
        try:
            (buf_id, src_time, recv_buff) = resp_queue.get(block=True, timeout=read_timeout)
            if pub_buf:  # TODO:topic?
                buf_sock.send_pyobj((buf_id, src_time, recv_buff))
            left_buff = left_buff+recv_buff
            if len(left_buff) < Message.header_len+Message.footer_len:  # wait for at least a full header and footer
                continue
            (message_type, body_len) = msg.get_message_header(left_buff)
            next_message_len = body_len + Message.header_len+Message.footer_len
            while next_message_len <= len(left_buff):
                try:
                    message_processor = factory.build_message_processor(message_type)
                    message = message_processor.unpack(left_buff)
                    message_id = message_id+1
                    LogUtil.debug("message:"+message.toString())
                    if pub_msg:  # TODO:topic?
                        src_time = datetime.datetime.now()
                        msg_sock.send_pyobj((message_type, message_id, src_time, message.message_str))
                except KeyError:
                    LogUtil.error("unknown message type:"+str(message_type))
                except Exception as e:
                    LogUtil.error(e)
                    LogUtil.error("other error:"+traceback.format_exc())
                left_buff = left_buff[next_message_len:]
                if len(left_buff) < Message.header_len+Message.footer_len:
                    break
                else:
                    (message_type, body_len) = msg.get_message_header(left_buff)
                    next_message_len = body_len + Message.header_len+Message.footer_len
        except queue.Empty:
            LogUtil.debug("resp_queue no data")
#        except KeyError:
#            LogUtil.error("unkown message type:"+str(message_type))
        else:
            pass
    LogUtil.info("md_responsor:"+name+" end")
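
The TODO at the top of the receive loop suggests factoring the framing check out into a next_message_ready() helper. A sketch of what that could look like, built only from Message.header_len, Message.footer_len and msg.get_message_header() as used above (not part of the original code):

def next_message_ready(buff):
    # True only when buff holds at least one complete message (header + body + footer)
    if len(buff) < Message.header_len + Message.footer_len:
        return False, None, 0
    message_type, body_len = msg.get_message_header(buff)
    total_len = body_len + Message.header_len + Message.footer_len
    return len(buff) >= total_len, message_type, total_len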
Example #42
0
    def http_post(self,
                  url,
                  params,
                  token,
                  login_type="NORMAL",
                  crypto_type="3DES"):
        '''
        HTTP POST request
        :param url: request URL
        :param params: request parameters, dict
        :param login_type: token type
        :param token: login token
        :param crypto_type: encryption type
        :return: the decrypted response
        '''
        # set request headers
        headers = {
            "content-type": "application/x-json",
            "x-client-appId": self.__APP_ID,
            "x-security-version": "2.0"
        }
        if crypto_type == "3DES":
            headers["x-client-fruit"] = "mango"
        elif crypto_type == "AES":
            headers["x-client-fruit"] = "watermelon"
        if login_type == self.LOGIN_TYPE_NORMAL:
            headers["x-client-type"] = "app"
            headers["x-client-os"] = "ios"
            headers["x-security-token"] = token
        elif login_type == self.LOGIN_TYPE_SUPPLIER:
            headers["x-client-type"] = "pc"
            headers["x-client-os"] = "web"
            headers["x-supplier-token"] = token
        elif login_type == self.LOGIN_TYPE_PMMANAGER or login_type == self.LOGIN_TYPE_MANAGER:
            headers["x-client-type"] = "pc"
            headers["x-client-os"] = "web"
            headers["x-manager-token"] = token
        elif login_type == self.LOGIN_TYPE_STAFF:
            headers["x-client-type"] = "app"
            headers["x-client-os"] = "ios"
            headers["x-manager-token"] = token

        # build the request content
        if params is None:  # no parameters supplied, fall back to an empty dict
            params = {}
        signature = self.cal_signature(params)
        post_content = {"signature": signature, "params": params}
        if (isinstance(post_content, dict)):
            LogUtil().info("request params:====>" + json.dumps(post_content,
                                                     sort_keys=True,
                                                     ensure_ascii=False,
                                                     separators=(',', ':')))
            post_content = json.dumps(post_content, separators=(',', ':'))
            # encrypted request content
            secret = self.cal_secret(token, crypto_type)
            payload = self.encrypt(post_content, secret, crypto_type)

            result = None
            # send the HTTP POST request
            try:
                response = requests.post(url, data=payload, headers=headers)
                # parse the HTTP response
                try:
                    if (response.status_code == 200):
                        json_temp = json.loads(response.text)
                        if (json_temp["msgCode"] == 200):
                            decrypt_result = self.decrypt(
                                str(json_temp["data"]), secret, crypto_type)
                            s = re.compile(
                                '[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f]').sub(
                                    '', decrypt_result)
                            result = json.loads(s)
                            LogUtil().info("decrypted result:====>\n" +
                                           json.dumps(result,
                                                      sort_keys=True,
                                                      ensure_ascii=False,
                                                      separators=(',', ':')))
                        else:
                            LogUtil().warning("request params (warning):====>\n" + url +
                                              "  " +
                                              json.dumps(params,
                                                         sort_keys=True,
                                                         ensure_ascii=False,
                                                         separators=(',',
                                                                     ':')))
                            LogUtil().warning("msgCode is not 200:====>" +
                                              str(json_temp))
                            return json_temp
                    else:
                        LogUtil().error("request params (error):====>\n" + url + "  " +
                                        json.dumps(params,
                                                   sort_keys=True,
                                                   ensure_ascii=False,
                                                   separators=(',', ':')))
                        LogUtil().error('HTTP response:=====>' + response.text)
                except:
                    LogUtil().error("request params (error):====>\n" + url + "  " +
                                    json.dumps(params,
                                               sort_keys=True,
                                               ensure_ascii=False,
                                               separators=(',', ':')))
                    LogUtil().error("failed to decrypt '{}' ciphertext!----".format(crypto_type) +
                                    traceback.format_exc())
            except:
                LogUtil().error("request params (error):====>\n" + url + "  " +
                                json.dumps(params,
                                           sort_keys=True,
                                           ensure_ascii=False,
                                           separators=(',', ':')))
                LogUtil().error("HTTP request failed!----" + traceback.format_exc())
            return result
        else:
            return "request params must be a dict!"
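
A hypothetical call, only to illustrate the signature documented in the docstring above; the class name, URL, token and parameters are all assumptions, not taken from the listing:

# client = ApiClient()                        # whatever class defines http_post
# result = client.http_post("https://example.com/api/order/list",
#                           {"pageNo": 1, "pageSize": 20},
#                           token="...",
#                           login_type="NORMAL",
#                           crypto_type="3DES")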
Example #43
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,is_logging=enable_logging_validation_metric,is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,is_logging=enable_logging_train_metric,is_epoch_end=False)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    n_epoch=begin_epoch

    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))


    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        if optimizer == "sgd":
            module.init_optimizer(kvstore='device',
                                  optimizer=optimizer,
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'momentum': momentum,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=force_init)
        elif optimizer == "adam":
            module.init_optimizer(kvstore='device',
                                  optimizer=optimizer,
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    #'momentum': momentum,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=force_init)
        else:
            raise Exception('Supported optimizers are sgd and adam. If you want to implement others define them in train.py')
    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)

    #tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)
    while True:

        if n_epoch >= num_epoch:
            break

        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):

            module.forward_backward(data_batch)
            module.update()
            # tensorboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch+1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args)+"n_epoch"+str(n_epoch)+"n_batch", epoch=(int((nbatch+1)/save_checkpoint_every_n_batch)-1), save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # when is_train = False it leads to high cer when batch_norm
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # tensorboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        lr_scheduler.learning_rate=learning_rate/learning_rate_annealing

    log.info('FINISH')
Example #44
0
class LabelUtil:
    _log = None

    # dataPath
    def __init__(self):
        self._log = LogUtil().getlogger()
        self._log.debug("LabelUtil init")

    def load_unicode_set(self, unicodeFilePath):
        self.byChar = {}
        self.byIndex = {}
        self.unicodeFilePath = unicodeFilePath

        with open(unicodeFilePath) as data_file:
            data_file = csv.reader(data_file, delimiter=',')

            self.count = 0
            for r in data_file:
                self.byChar[r[0]] = int(r[1])
                self.byIndex[int(r[1])] = r[0]
                self.count += 1


    def to_unicode(self, src, index):
        # 1 byte
        code1 = int(ord(src[index + 0]))

        index += 1

        result = code1

        return result, index

    def convert_word_to_grapheme(self, label):

        result = []

        index = 0
        while index < len(label):
            (code, nextIndex) = self.to_unicode(label, index)

            result.append(label[index])

            index = nextIndex

        return result, "".join(result)

    def convert_word_to_num(self, word):
        try:
            label_list, _ = self.convert_word_to_grapheme(word)

            label_num = []

            for char in label_list:
                # skip word
                if char == "":
                    pass
                else:
                    label_num.append(int(self.byChar[char]))

            # tuple typecast: read only, faster
            return tuple(label_num)

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

        except KeyError as err:
            self._log.error("unicodeSet Key not found: %s" % err)
            exit(-1)

    def convert_bi_graphemes_to_num(self, word):
        label_num = []

        for char in word:
            # skip word
            if char == "":
                pass
            else:
                label_num.append(int(self.byChar[char]))

        # tuple typecast: read only, faster
        return tuple(label_num)


    def convert_num_to_word(self, num_list):
        try:
            label_list = []
            for num in num_list:
                label_list.append(self.byIndex[num])

            return ''.join(label_list)

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

        except KeyError as err:
            self._log.error("unicodeSet Key not found: %s" % err)
            exit(-1)

    def get_count(self):
        try:
            return self.count

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

    def get_unicode_file_path(self):
        try:
            return self.unicodeFilePath

        except AttributeError:
            self._log.error("unicodeSet is not loaded")
            exit(-1)

    def get_blank_index(self):
        return self.byChar["-"]

    def get_space_index(self):
        return self.byChar["$"]
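
A small usage sketch for LabelUtil (the CSV path and the sample word are placeholders; the file format is the two-column character,index mapping read by load_unicode_set above):

labelUtil = LabelUtil()
labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")  # path is an assumption
nums = labelUtil.convert_word_to_num("hello")    # tuple of label indices
word = labelUtil.convert_num_to_word(nums)       # back to "hello"
print(labelUtil.get_count())                     # number of characters in the map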
Example #45
0
    def run(self):
        self.dialog.pushButtonMsgRecvStart.setEnabled(False)
        self.dialog.pushButtonMsgRecvStop.setEnabled(True)
        
        LogUtil.getLoggerInstance(self.config, "msg_recv")
        LogUtil.info("msg_recv:"+" begin")
        
        config=configparser.ConfigParser()
        config.read(self.config)

        write_msg=config.get("msg_recv", "write_msg")
        if write_msg:
            md_host=config.get("msg_recv", "md_host")
            md_database=config.get("msg_recv", "md_database")
            md_user=config.get("msg_recv", "md_user")
            md_password=config.get("msg_recv", "md_password")            
            md_conn=pgdb.connect(database=md_database, host=md_host, user=md_user, password=md_password)
            md_insert_msg_cursor=md_conn.cursor()
            md_insert_msg_sql='''insert into market_data_message(data_date,insert_time,message_type,message_id,message_content,src_time)
            values(%(data_date)s,%(insert_time)s,%(message_type)s,%(message_id)s,%(message_content)s,%(src_time)s)
            '''
            md_insert_msg_dict={'data_date':0, 'message_type':0, 'message_content':''}

        msg_sub_addr=config.get("msg_recv", "msg_sub_addr")
        msg_sub_topic=config.get("msg_recv", "msg_sub_topic")
        ctx=zmq.Context()
        sock=ctx.socket(zmq.SUB)
        sock.connect(msg_sub_addr)
        sock.setsockopt_string(zmq.SUBSCRIBE, msg_sub_topic)
        
        msgRecvCnt=0
        msgWriteCnt=0
        msgUpdateTime=None
        msgErrCnt=0        
        
        recvMsgStatus=RecvMsgStatus()
        msgStatus='Running'
        
        while not self.toStop:
            try:
                (message_type, message_id, src_time, recv_msg)=sock.recv_pyobj()
                msgRecvCnt=msgRecvCnt+1
                msgUpdateTime=datetime.datetime.now().strftime('%H:%M:%S.%f')
                if write_msg:
                    md_insert_msg_dict['data_date']=0
                    md_insert_msg_dict['insert_time']=datetime.datetime.now()
                    md_insert_msg_dict['message_type']=message_type
                    md_insert_msg_dict['message_id']=message_id
                    md_insert_msg_dict['message_content']=recv_msg
                    md_insert_msg_dict['src_time']=src_time
                    md_insert_msg_cursor.execute(md_insert_msg_sql, md_insert_msg_dict)
                    #TODO
                    md_conn.commit()
                    msgWriteCnt=msgWriteCnt+1
            except Exception as e:
                msgErrCnt=msgErrCnt+1
                LogUtil.error(e)
                LogUtil.error("MsgRecvErr:msgErrCnt="+str(msgErrCnt)+",msgRecvCnt="+str(msgRecvCnt)+",msgWriteCnt="+str(msgWriteCnt))
            finally:
                pass
            recvMsgStatus.msgRecvCnt=str(msgRecvCnt)
            recvMsgStatus.msgWriteCnt=str(msgWriteCnt)
            recvMsgStatus.msgErrCnt=str(msgErrCnt)
            recvMsgStatus.msgUpdateTime=msgUpdateTime
            recvMsgStatus.msgStatus=msgStatus
            self.msgStatusUpdated.emit(recvMsgStatus)
            
            LogUtil.debug("MsgRecvErr:msgErrCnt="+str(msgErrCnt)+",msgRecvCnt="+str(msgRecvCnt)+",msgWriteCnt="+str(msgWriteCnt))
            
        self.dialog.pushButtonMsgRecvStart.setEnabled(True)
        self.dialog.pushButtonMsgRecvStop.setEnabled(False)
Example #46
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch,
                         nbatch)
                module.save_checkpoint(
                    prefix=get_checkpoint_path(args) + "n_epoch" +
                    str(n_epoch) + "n_batch",
                    epoch=(int(
                        (nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # with is_train=False, batch norm falls back to running statistics and the CER degrades,
            # so the forward pass is kept in training mode here
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                 int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        lr_scheduler.learning_rate = learning_rate / learning_rate_annealing

    log.info('FINISH')
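
The train and validation passes above report a character error rate (CER) accumulated by the metrics. As a rough, hypothetical illustration of what such a metric computes (this is not the project's STTMetric), CER can be derived from Levenshtein edit distance:

def edit_distance(ref, hyp):
    # classic single-row dynamic-programming Levenshtein distance
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1,        # deletion
                                     dp[j - 1] + 1,    # insertion
                                     prev + (r != h))  # substitution
    return dp[len(hyp)]

def character_error_rate(references, hypotheses):
    # CER = total edit distance / total reference length
    total_dist = sum(edit_distance(r, h) for r, h in zip(references, hypotheses))
    total_len = sum(len(r) for r in references)
    return total_dist / max(total_len, 1)
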
Example #47
0
    def run(self):
        self.dialog.pushButtonBufRecvStart.setEnabled(False)
        self.dialog.pushButtonBufRecvStop.setEnabled(True)
    
        LogUtil.getLoggerInstance(self.config, "buf_recv")
        LogUtil.info("buf_recv:"+" begin")
        
        config=configparser.ConfigParser()
        config.read(self.config)

        write_buff=config.get("buf_recv", "write_buff")
        if write_buff:
            md_host=config.get("buf_recv", "md_host")
            md_database=config.get("buf_recv", "md_database")
            md_user=config.get("buf_recv", "md_user")
            md_password=config.get("buf_recv", "md_password")            
            md_conn=pgdb.connect(database=md_database, host=md_host, user=md_user, password=md_password)
            md_insert_buf_cursor=md_conn.cursor()
            #md_insert_buf_sql='''insert into market_data_buff(data_date,insert_time,buff)
            #values(%(data_date)s,localtimestamp,%(buff)s)
            #'''
            md_insert_buf_sql='''insert into market_data_buff(data_date,buff_id,insert_time,buff,src_time)
            values(%(data_date)s,%(buf_id)s,%(insert_time)s,%(buff)s,%(src_time)s)
            '''
            md_insert_buf_dict={'data_date':0, 'buf_id':0, 'insert_time':None, 'buff':'', 'src_time':None}
        
        buf_sub_addr=config.get("buf_recv", "buf_sub_addr")
        buf_sub_topic=config.get("buf_recv", "buf_sub_topic")
        ctx=zmq.Context()
        sock=ctx.socket(zmq.SUB)
        sock.connect(buf_sub_addr)
        sock.setsockopt_string(zmq.SUBSCRIBE, buf_sub_topic)
        
        bufRecvCnt=0
        bufWriteCnt=0
        bufUpdateTime=None
        bufErrCnt=0     
        bufStatus='Running'   
        recvBufStatus=RecvBufStatus()        
        
        while not self.toStop:
            try:
                (buf_id, src_time, recv_buff)=sock.recv_pyobj()
                bufRecvCnt=bufRecvCnt+1
                bufUpdateTime=datetime.datetime.now().strftime('%H:%M:%S.%f')
                if write_buff:
                    md_insert_buf_dict['data_date']=0
                    md_insert_buf_dict['insert_time']=datetime.datetime.now()
                    md_insert_buf_dict['buf_id']=buf_id
                    md_insert_buf_dict['buff']=binascii.hexlify(recv_buff).decode()
                    #msg=recv_buff.decode()
                    md_insert_buf_dict['src_time']=src_time
                    md_insert_buf_cursor.execute(md_insert_buf_sql, md_insert_buf_dict)
                    #TODO
                    md_conn.commit()
                    bufWriteCnt=bufWriteCnt+1
                    #self.dialog.bufWriteCnt.setText(str(bufWriteCnt))
            except Exception as e:
                bufErrCnt=bufErrCnt+1
                #self.dialog.bufErrCnt.setText(str(bufErrCnt))
                LogUtil.error(e)
                LogUtil.error("BufRecvErr:bufErrCnt="+str(bufErrCnt)+",bufRecvCnt="+str(bufRecvCnt)+",bufWriteCnt="+str(bufWriteCnt))
            finally:
                pass
            #TODO    
            recvBufStatus.bufRecvCnt=str(bufRecvCnt)
            recvBufStatus.bufWriteCnt=str(bufWriteCnt)
            recvBufStatus.bufErrCnt=str(bufErrCnt)
            recvBufStatus.bufUpdateTime=bufUpdateTime
            recvBufStatus.bufStatus=bufStatus
            self.bufStatusUpdated.emit(recvBufStatus)
            
            LogUtil.debug("BufRecvErr:bufErrCnt="+str(bufErrCnt)+",bufRecvCnt="+str(bufRecvCnt)+",bufWriteCnt="+str(bufWriteCnt))
            
        self.dialog.pushButtonBufRecvStart.setEnabled(True)
        self.dialog.pushButtonBufRecvStop.setEnabled(False)
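
The loop above unpickles (buf_id, src_time, buffer) tuples from a ZeroMQ SUB socket. For orientation, a minimal sketch of a matching publisher is shown below; the bind address and payload are assumptions, and since send_pyobj() frames start with pickle bytes rather than a text topic, the subscriber's SUBSCRIBE prefix is assumed to be empty:

import datetime
import time
import zmq

def publish_buffers(pub_addr="tcp://*:5556"):  # hypothetical bind address
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUB)
    sock.bind(pub_addr)
    buf_id = 0
    while True:
        buf_id += 1
        src_time = datetime.datetime.now()
        recv_buff = b"\x00\x01\x02\x03"  # placeholder for a raw market-data frame
        # one pickled frame per message, matching sock.recv_pyobj() on the SUB side
        sock.send_pyobj((buf_id, src_time, recv_buff))
        time.sleep(1)
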
Example #48
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate, momentum=momentum, optimizer=optimizer)

    n_epoch = begin_epoch
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')

    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        module.init_optimizer(kvstore='device',
                              optimizer=args.config.get('train', 'optimizer'),
                              optimizer_params={'clip_gradient': clip_gradient,
                                                'wd': weight_decay},
                              force_init=True)

    reset_optimizer()

    while True:

        if n_epoch >= num_epoch:
            break

        eval_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):

            if data_batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()
            if (nbatch+1) % show_every == 0:
                module.update_metric(eval_metric, data_batch.label)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        for nbatch, data_batch in enumerate(data_val):
            module.update_metric(eval_metric, data_batch.label)
        #module.score(eval_data=data_val, num_batch=None, eval_metric=eval_metric, reset=True)

        data_train.reset()
        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
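
Both training functions drive a SimpleLRScheduler whose attributes (learning_rate, effective_sample_count) are mutated directly rather than following a fixed step schedule. A minimal sketch assuming MXNet's LRScheduler interface; the real class in this project also handles momentum/optimizer and effective_sample_count, as the calls above suggest:

from mxnet.lr_scheduler import LRScheduler

class SimpleLRScheduler(LRScheduler):
    """Hand the optimizer whatever learning rate is currently stored,
    so the training loop can anneal it by assigning to .learning_rate."""

    def __init__(self, learning_rate=0.001, **kwargs):
        super(SimpleLRScheduler, self).__init__(**kwargs)
        self.learning_rate = learning_rate

    def __call__(self, num_update):
        return self.learning_rate
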
Example #49
0
    def unpack_varlen(self, buff, separator=',', write_tag_name=True):
        message=Message()
        message.separator=separator
        message.has_tag_name=write_tag_name
        message.message_type=self.message_type
        
        (message_type, body_len)=self.header_struct.unpack_from(buff)
        if message_type!=self.message_type:
            #TODO:error format check
            LogUtil.error("Error message_type,expected="+str(self.message_type)+",act_type="+str(message_type))
            return
        #TODO:check buff size
        #TODO:check checksum       
        bytes_processed=0
        body_left_buff=buff[self.header_struct.size:]
        #body_tuple=self.body_struct.unpack_from(body_buff)
        str_value_list=[]
        rtn_str=''
        #field_list:
        #0:field_name;
        #1:format_str:
        #2:type,eg N,C;
        #3:type_len;
        #4:ref_var_len_field
        for field_index in range(len(self.field_list)):
            if field_index>0:
                rtn_str=rtn_str+separator
            if write_tag_name:
                rtn_str=rtn_str+self.field_list[field_index][0]+'='
            
            if self.field_list[field_index][4]>0:
                #self.field_list[field_index][4] is the index of the field that holds the length
                format_str="!"+str_value_list[self.field_list[field_index][4]]+"s"
                bytes_len=int(str_value_list[self.field_list[field_index][4]])
                LogUtil.debug("field:"+self.field_list[field_index][0]+" format:"+format_str)
            else:
                format_str="!"+self.field_list[field_index][1]
                bytes_len=self.field_list[field_index][3]
                LogUtil.debug("field:"+self.field_list[field_index][0]+" format:"+format_str)
            field_value=struct.unpack_from(format_str, body_left_buff)
            if self.field_list[field_index][2]=='C':
                field_str_value=bytes.decode(field_value[0])
            elif self.field_list[field_index][2]=='N':
                field_str_value=str(field_value[0])
            else:    
                field_str_value=bytes.decode(field_value[0])
        
            str_value_list.append(field_str_value)
            rtn_str=rtn_str+field_str_value
            #TODO add a column to message_body_def to indicate whether the column is client Order ID
            if self.field_list[field_index][0]=='ClOrdID':
                message.client_order_id=field_str_value
            elif self.field_list[field_index][0]=='OrdStatus':
                message.order_status=field_str_value
            elif self.field_list[field_index][0]=='OrdRejReason':
                message.order_reject_reason=field_value[0]
            bytes_processed=bytes_processed+bytes_len    
            body_left_buff=body_left_buff[bytes_len:]
            
        message.message_str=rtn_str
        if bytes_processed!=body_len:
            LogUtil.error("bytes_process!=body_len,mesType="+str(message_type)+",body_len="+str(body_len)+",bytes_processed="+str(bytes_processed))

        return message
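
unpack_varlen builds a "!<N>s" struct format at run time from a length field decoded earlier in the same message. A small self-contained sketch of that pattern, using a hypothetical two-byte length prefix rather than the real message_body_def layout:

import struct

# layout: 2-byte unsigned length (network order) followed by that many bytes of text
def pack_varlen(text):
    data = text.encode()
    return struct.pack("!H", len(data)) + data

def unpack_varlen(buff):
    (length,) = struct.unpack_from("!H", buff)   # fixed-width length field
    fmt = "!" + str(length) + "s"                # format string built at run time
    (raw,) = struct.unpack_from(fmt, buff, 2)    # offset past the length field
    return raw.decode()

assert unpack_varlen(pack_varlen("ClOrdID=12345")) == "ClOrdID=12345"
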
Example #50
0
def load_data(args):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'mode must be the one of the followings - train,predict,load')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train',
                                                  'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean(
        'train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()
    if mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json, max_duration=max_duration)
        datagen.load_validation_data(val_json, max_duration=max_duration)
        if is_bi_graphemes:
            if not os.path.isfile(
                    "resources/unicodemap_en_baidu_bi_graphemes.csv"
            ) or overwrite_bi_graphemes_dictionary:
                load_labelutil(labelUtil=labelUtil,
                               is_bi_graphemes=False,
                               language=language)
                generate_bi_graphemes_dictionary(datagen.train_texts +
                                                 datagen.val_texts)
        load_labelutil(labelUtil=labelUtil,
                       is_bi_graphemes=is_bi_graphemes,
                       language=language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                log.info("Generate mean and std from samples")
                normalize_target_k = args.config.getint(
                    'train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                log.info("Read mean and std from meta files")
                datagen.get_meta_from_file(
                    np.loadtxt(
                        generate_file_path(save_dir, model_name,
                                           'feats_mean')),
                    np.loadtxt(
                        generate_file_path(save_dir, model_name, 'feats_std')))
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(
                    generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(
                    generate_file_path(save_dir, model_name, 'feats_std')))

    elif mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json, max_duration=max_duration)
        labelutil = load_labelutil(labelUtil, is_bi_graphemes, language="en")
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train'
                                             or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by its duration in ascending order to implement sortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = \
            datagen.get_max_label_length(partition="train", is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = \
            datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean(
        'train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(
            partition="train",
            count=datagen.count,
            datagen=datagen,
            batch_size=batch_size,
            num_label=max_label_length,
            init_states=init_states,
            seq_length=max_t_count,
            width=whcs.width,
            height=whcs.height,
            sort_by_duration=sort_by_duration,
            is_bi_graphemes=is_bi_graphemes,
            buckets=buckets,
            save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    if mode == 'train' or mode == 'load':
        if is_bucketing:
            validation_loaded = BucketSTTIter(
                partition="validation",
                count=datagen.val_count,
                datagen=datagen,
                batch_size=batch_size,
                num_label=max_label_length,
                init_states=init_states,
                seq_length=max_t_count,
                width=whcs.width,
                height=whcs.height,
                sort_by_duration=False,
                is_bi_graphemes=is_bi_graphemes,
                buckets=buckets,
                save_feature_as_csvfile=save_feature_as_csvfile)
        else:
            validation_loaded = STTIter(
                partition="validation",
                count=datagen.val_count,
                datagen=datagen,
                batch_size=batch_size,
                num_label=max_label_length,
                init_states=init_states,
                seq_length=max_t_count,
                width=whcs.width,
                height=whcs.height,
                sort_by_duration=False,
                is_bi_graphemes=is_bi_graphemes,
                save_feature_as_csvfile=save_feature_as_csvfile)
        return data_loaded, validation_loaded, args
    elif mode == 'predict':
        return data_loaded, args
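
load_data reads every option from an INI-style config via args.config. A hypothetical minimal configuration covering the options referenced above could be assembled as follows; the values are placeholders, not the project's defaults:

import configparser

cfg_text = """
[common]
mode = train
batch_size = 2
prefix = deepspeech
is_bi_graphemes = True

[data]
width = 161
height = 1
channel = 1
stride = 1
train_json = resources/train.json
val_json = resources/val.json
test_json = resources/test.json
max_duration = 16.0
language = en

[arch]
arch_file = arch_deepspeech
is_bucketing = True
buckets = [200, 300, 400, 500, 600]
is_batchnorm = True

[train]
overwrite_meta_files = True
overwrite_bi_graphemes_dictionary = False
normalize_target_k = 100
save_feature_as_csvfile = False
"""

config = configparser.ConfigParser()
config.read_string(cfg_text)
assert config.getint('common', 'batch_size') == 2
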
Example #51
0
    def unpack_nested(self, buff, separator=',', write_tag_name=True):
        message=Message()
        message.separator=separator
        message.has_tag_name=write_tag_name
        message.message_type=self.message_type
        
        (message_type, body_len)=self.header_struct.unpack_from(buff)
        if message_type!=self.message_type:
            #TODO:error format check
            LogUtil.error("Error message_type,expected="+str(self.message_type)+",act_type="+str(message_type))
            return
        #TODO:check buff size
        #TODO:check checksum       
        
        body_left_buff=buff[self.header_struct.size:]
        bytes_processed=0
        #body_tuple=self.body_struct.unpack_from(body_buff)
        str_value_list=[]
        rtn_str=''
        #field_list:
        #0:field_name;
        #1:format_str:
        #2:type,eg N,C;
        #3:type_len;
        #4:ref_field
        #5:struct_tag
        field_index=0
        #suppose we encounter a body begin tag
        (loop_cnt_field, loop_cnt, exec_cnt)=(0, 1, 1)
        (var_len_indictor_field, var_len)=(0, 0)
        loop_list=[]
        while field_index<len(self.field_list) and self.field_list[field_index][5]!='BODY_END':
            if self.field_list[field_index][5] in('DATA', 'LOOP_CNT', 'NEXT_VAR_LEN', 'VAR_LEN'):            
                if field_index>0:
                    rtn_str=rtn_str+separator
                if write_tag_name:
                    rtn_str=rtn_str+self.field_list[field_index][0]+'='
                
                if self.field_list[field_index][5]=='VAR_LEN':
                    #self.field_list[field_index][4]-1 is the index of the field that holds the length
                    if self.field_list[field_index][4]!=var_len_indictor_field:
                        LogUtil.error("Error getting VAR_LEN:field_index="+str(field_index)+" indict field="+str(var_len_indictor_field))
                    format_str="!"+str(var_len)+"s"
                    bytes_len=var_len
                    LogUtil.debug("field:"+self.field_list[field_index][0]+" format:"+format_str)
                else:
                    format_str="!"+self.field_list[field_index][1]
                    bytes_len=self.field_list[field_index][3]
                    LogUtil.debug("field:"+self.field_list[field_index][0]+" format:"+format_str)
                field_value=struct.unpack_from(format_str, body_left_buff)
                
                if self.field_list[field_index][5]=='NEXT_VAR_LEN':
                    var_len=field_value[0]
                    var_len_indictor_field=field_index
                    
                if self.field_list[field_index][2]=='C':
                    field_str_value=bytes.decode(field_value[0])
                elif self.field_list[field_index][2]=='N':
                    field_str_value=str(field_value[0])
                else:
                    field_str_value=bytes.decode(field_value[0])
                    LogUtil.error("unknown type category:"+self.field_list[field_index][2])
                
                str_value_list.append(field_str_value)
                rtn_str=rtn_str+field_str_value
                #TODO add a column to message_body_def to indicate whether the column is client Order ID
                if self.field_list[field_index][0]=='ClOrdID':
                    message.client_order_id=field_str_value
                elif self.field_list[field_index][0]=='OrdStatus':
                    message.order_status=field_str_value
                elif self.field_list[field_index][0]=='OrdRejReason':
                    message.order_reject_reason=field_value[0]
                bytes_processed=bytes_processed+bytes_len    
                body_left_buff=body_left_buff[bytes_len:]
                
                if self.field_list[field_index][5]=='LOOP_CNT':
                    loop_list.append((field_index, loop_cnt, exec_cnt))
                    LogUtil.debug("push:field_index="+str(field_index)+ ",loop="+str(loop_cnt)+ ",exec="+str(exec_cnt))
                    loop_cnt=field_value[0]
                    exec_cnt=0
                field_index=field_index+1    
            elif self.field_list[field_index][5]=='LOOP_BEGIN':
                if exec_cnt<loop_cnt:
                    exec_cnt=exec_cnt+1
                    field_index=field_index+1
                else: #end loop, move index to LOOP_END+1
                    (loop_cnt_field, loop_cnt, exec_cnt)=loop_list.pop()
                    LogUtil.debug("pop:field_index="+str(loop_cnt_field)+",loop="+str(loop_cnt)+ ",exec="+str(exec_cnt))
                    field_index=self.field_list[field_index][4]               

            elif self.field_list[field_index][5]=='LOOP_END':
                #goto LOOP BEGIN(which is saved in ref_field)
                field_index=self.field_list[field_index][4]
                if exec_cnt>loop_cnt:
                    LogUtil.critical("LOOP_END error: exec_cnt="+str(exec_cnt)+",loop_cnt="+str(loop_cnt))
            else:
                #TODO:raiseExceptions
                LogUtil.error("unsupported struct tag:"+self.field_list[field_index][5])
                break
        message.message_str=rtn_str
        if bytes_processed!=body_len:
            LogUtil.error("bytes_process!=body_len,mesType="+str(message_type)+",body_len="+str(body_len)+",bytes_processed="+str(bytes_processed))
        return message
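
unpack_nested walks LOOP_CNT / LOOP_BEGIN / LOOP_END tags to decode repeating groups. The following sketch shows the same idea with a hypothetical fixed layout (a one-byte group count followed by symbol/quantity pairs), independent of the struct_tag machinery:

import struct

# hypothetical body: 1-byte group count, then per group a 6-byte symbol and a 4-byte unsigned quantity
def pack_groups(groups):
    body = struct.pack("!B", len(groups))
    for symbol, qty in groups:
        body += struct.pack("!6sI", symbol.encode().ljust(6), qty)
    return body

def unpack_groups(body):
    (count,) = struct.unpack_from("!B", body)    # LOOP_CNT-style field
    groups, offset = [], 1
    for _ in range(count):                       # LOOP_BEGIN .. LOOP_END
        symbol, qty = struct.unpack_from("!6sI", body, offset)
        groups.append((symbol.decode().strip(), qty))
        offset += struct.calcsize("!6sI")
    return groups

assert unpack_groups(pack_groups([("AAPL", 100), ("MSFT", 50)])) == [("AAPL", 100), ("MSFT", 50)]
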
Example #52
0
import socket
import time
import google.protobuf
import msg_pb2
import sys
import traceback
import threading
from ini_op import Config
from log_util import LogUtil
from printer import Printer
from PyQt5.QtNetwork import *
# ini config file
from PyQt5.QtPrintSupport import QPrinter
import urllib.request

log = LogUtil()
config = Config("config.ini")
port = 9999  # port
# host = 'localhost'  # OM server
host = config.get("baseconf", "oms_host")
BUFSIZE = 8192

# set connection timeout: the client's file-receiving server must stay reachable
# the recording-file receiver runs on the same host as the application server, so the link will not drop
# socket.setdefaulttimeout(0.01)
# connect to the OM server

class OMClient(threading.Thread):
    def reconnect(self):
        try:
            self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Example #53
0
def generate_news_data_by_jieba(origin_news_file_dir):
    """
        word_flag_dict:{word:flag, ...}
        news_ctg_title_data:[[ctg,[title words],[content words]], ...]
    """
    stopWords = StopWords(STOPWORD_FILE_NAME)
    log = LogUtil(NLP_LOG_NAME)

    all_news_ctg_title_data = []
    word_flag_dict = {}

    for file_name in os.listdir(origin_news_file_dir):
        if file_name[-4:] != 'json':
            continue

        log.log_info('now process origin news file is : %s' % file_name)
        origin_news_data = load_file(origin_news_file_dir + file_name)
        count = 0

        for news in origin_news_data:
            if count % 100 == 0:
                log.log_info(
                    'now process origin news file is : %s, now processed news number is %s'
                    % (file_name, count))
            count += 1

            news_ctg_title_data = []
            news_json = json.loads(news)  # json.loads no longer accepts an encoding argument
            ctg = news_json['category'].split('_')[2]
            title = news_json['title']
            content = news_json.get('content', '')

            title_words = util.split_sentence_to_word(title)
            title_words_filter_stopwords = []
            title_word_list = []
            for item in title_words:
                if not stopWords.is_stopword(item[0]):
                    title_words_filter_stopwords.append(item)
                    word_flag_dict[item[0]] = item[1]
                    title_word_list.append(item[0])

            content_words = util.split_sentence_to_word(content)
            content_words_filter_stopwords = []
            content_word_list = []
            for item in content_words:
                if not stopWords.is_stopword(item[0]):
                    content_words_filter_stopwords.append(item)
                    word_flag_dict[item[0]] = item[1]
                    content_word_list.append(item[0])

            news_ctg_title_data.append(ctg)
            news_ctg_title_data.append(title_word_list)
            news_ctg_title_data.append(content_word_list)
            all_news_ctg_title_data.append(news_ctg_title_data)

    log.log_info('start to save word_flag_dict')
    util.save_data_by_cPickle(word_flag_dict, WORD_FLAG_DICT_NAME)
    log.log_info('save word_flag_dict success')

    log.log_info('start to save all_news_ctg_title_data')
    util.save_data_by_cPickle(all_news_ctg_title_data,
                              ALL_NEWS_CTG_TITLE_CONTENT_FILE_NAME)
    log.log_info('save all_news_ctg_title_data success')
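
util.split_sentence_to_word is consumed above as (word, flag) pairs, which matches jieba's part-of-speech tokenizer. A hedged sketch of such a helper (the real util module is not shown in this snippet):

import jieba.posseg as pseg

def split_sentence_to_word(sentence):
    # return (word, part-of-speech flag) tuples, mirroring the items consumed above
    return [(pair.word, pair.flag) for pair in pseg.cut(sentence)]

# usage: item[0] is the token kept in the word lists, item[1] the POS flag stored in word_flag_dict
# e.g. split_sentence_to_word('今天天气不错') yields pairs such as ('天气', 'n')
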
Example #54
0
            model_loaded = mx.module.Module.load(prefix=model_path, epoch=model_num_epoch, context=contexts,
                                                 data_names=data_names, label_names=label_names,
                                                 load_optimizer_states=False)

    return model_loaded, model_num_epoch


if __name__ == '__main__':
    if len(sys.argv) <= 1:
        raise Exception('cfg file path must be provided. ex)python main.py --configfile examplecfg.cfg')
    mx.random.seed(hash(datetime.now()))
    # set parameters from cfg file
    args = parse_args(sys.argv[1])

    log_filename = args.config.get('common', 'log_filename')
    log = LogUtil(filename=log_filename).getlogger()

    # set parameters from data section(common)
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'Define mode in the cfg file first. train or predict or load can be the candidate for the mode.')

    # get meta file where character to number conversions are defined
    language = args.config.get('data', 'language')
    labelUtil = LabelUtil.getInstance()
    if language == "en":
        labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
    else:
        raise Exception("Error: Language Type: %s" % language)
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
Example #55
0
import socket
import time
import google.protobuf
import msg_pb2
import sys
import traceback
import threading
from ini_op import Config
from log_util import LogUtil
from printer import Printer
from downloader import Downloader
# ini config file
from PyQt5.QtPrintSupport import QPrinter

log = LogUtil()
config = Config("config.ini")
port = 9999  # port
# host = 'localhost'  # OM server
host = config.get("baseconf", "oms_host")
BUFSIZE = 8192

# set connection timeout
socket.setdefaulttimeout(0.01)
# connect to the OM server
class OMClient(threading.Thread):
    def reconnect(self):
        try:
            self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.client.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)  # enable TCP keepalive on the client side
            self.client.connect((host, port))