Пример #1
0
def db_reader(name, status, req_queue, config_file):
    """Worker loop: poll unreported rows from the PostgreSQL req_resp table,
    pack each request and push it onto req_queue, then mark the row reported.

    Params:
        name (str): worker name, used only in log messages
        status: shared object with a .value flag; loop runs while value == 0
        req_queue: queue receiving the packed message buffers
        config_file (str): path to the ini configuration file
    """
    LogUtil.get_instance(config_file, "db_reader")
    LogUtil.info("db_reader:"+name+" begin")

    config=configparser.ConfigParser()
    config.read(config_file)
    
    factory=msg.MessageProcessorFactory()
    db_file=config.get("message_config", "db_file")
    factory.load_from_db(db_file, [])
    
    # NOTE(review): getfloat() raises on a missing option rather than
    # returning None, so this fallback rarely fires; kept for safety.
    read_interval=config.getfloat("reqresp","read_interval")
    if read_interval is None:  # fix: identity comparison instead of ==
        read_interval=0.5
    
    host=config.get("reqresp", "host")
    database=config.get("reqresp", "database")
    user=config.get("reqresp","user")
    password=config.get("reqresp", "password")
    
    last_req_num=0
    conn=pgdb.connect(database=database, host=host, user=user, password=password)
    query_curs=conn.cursor()
    update_curs=conn.cursor()
    query_unreported="""SELECT req_num,message_type,appid,oms_order_id,rept_status,req_text
        FROM req_resp
        WHERE rept_status='0'
        AND req_num>%(req_num)s
        ORDER BY req_num
    """
    query_dict={'req_num':0}
    
    update_reported="""UPDATE req_resp 
    SET rept_status=%(rept_status)s,report_time=localtimestamp
    WHERE req_num in (%(req_num)s)    
    """
    update_dict={'rept_status':'0', 'req_num':0}
    
    last_read_cnt=1
    try:
        while status.value==0:
            # Sleep only when the previous poll found nothing, so a busy
            # table is drained without artificial delay.
            if last_read_cnt==0:
                time.sleep(read_interval)

            last_read_cnt=0
            query_dict['req_num']=last_req_num
            query_curs.execute(query_unreported,query_dict)
            for (req_num,message_type,appid,message_id,rept_status,req_text) in query_curs.fetchall():
                message_processor=factory.build_message_processor(message_type)
                send_buff=message_processor.pack(req_text)
                req_queue.put(send_buff)
                last_req_num=req_num
                # '2' marks the row as reported
                update_dict['rept_status']='2'
                update_dict['req_num']=last_req_num

                update_curs.execute(update_reported,update_dict)
                last_read_cnt=last_read_cnt+1
                LogUtil.debug("db_reader putQ:"+binascii.hexlify(send_buff).decode())
            conn.commit()
    finally:
        # fix: close the connection on shutdown/error instead of leaking it
        conn.close()

    LogUtil.info("db_reader:"+name+" end")
Пример #2
0
def _get_lr_scheduler(args, kv):
    """Build the learning-rate schedule from the [train] config section.

    Returns (learning_rate, scheduler). scheduler is None when lr_factor >= 1
    (no decay requested). In "load" mode the starting epoch is recovered from
    the model file name and the rate is pre-decayed to match.
    """
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_factor = args.config.getfloat('train', 'lr_factor')
    if lr_factor >= 1:
        return (learning_rate, None)
    epoch_size = args.num_examples / args.batch_size
    if 'dist' in args.kv_store:
        # each distributed worker only sees its shard of the data
        epoch_size /= kv.num_workers
    mode = args.config.get('common', 'mode')
    begin_epoch = 0
    if mode == "load":
        # model file is either "<prefix>-NNNN" (length 16) or
        # "...n_epoch<E>n_batch<B>..."; recover the epoch from the name
        model_file = args.config.get('common', 'model_file')
        begin_epoch = int(
            model_file.split("-")[1]) if len(model_file) == 16 else int(
                model_file.split("n_epoch")[1].split("n_batch")[0])
    step_epochs = [
        int(x) for x in args.config.get('train', 'lr_step_epochs').split(',')
    ]
    # pre-apply the decay for every step already passed when resuming
    for s in step_epochs:
        if begin_epoch >= s:
            learning_rate *= lr_factor
    if learning_rate != args.config.getfloat('train', 'learning_rate'):
        log = LogUtil().getlogger()
        log.info('Adjust learning rate to %e for epoch %d' %
                 (learning_rate, begin_epoch))

    steps = [
        epoch_size * (x - begin_epoch) for x in step_epochs
        if x - begin_epoch > 0
    ]
    # fix: use the lr_factor read from the config above; the rest of this
    # function never reads args.lr_factor directly
    return (learning_rate,
            mx.lr_scheduler.MultiFactorScheduler(step=steps,
                                                 factor=lr_factor))
Пример #3
0
 def load_from_db(self, db_file, message_type_list):
     """Populate self.message_processors from an sqlite message-definition db.

     Params:
         db_file (str): path to the sqlite database holding message_body_def
         message_type_list (list): message types to load; when empty, every
             distinct type found in the db is loaded (the list is filled
             in place so the caller sees which types were discovered)
     """
     conn=sqlite3.connect(db_file)
     curs=conn.cursor()
     if not message_type_list:
         curs.execute('select distinct message_type from message_body_def order by message_type')
         message_type_list.extend(r[0] for r in curs.fetchall())
     field_query='''select t.message_type,t.field_name,t.field_desc,t.format_string,t.type_category,t.type_len,t.ref_field,t.field_order,t.struct_tag 
             from vw_message_body_def t
             where t.message_type=?
             order by t.field_order
             '''
     for mtype in message_type_list:
         fields=[]
         seen=0
         curs.execute(field_query, [mtype])
         for row in curs.fetchall():
             seen+=1
             LogUtil.info(row)
             # field_order is expected to be a gapless 0-based sequence
             if row[7]!=seen-1:
                 #TODO:raiseExceptions
                 LogUtil.error("field list disorder:mesType="+str(mtype)+" field:"+str(seen))
             fields.append((row[1], row[3], row[4], row[5], row[6], row[8]))
         self.message_processors[mtype]=MessageProcessor(mtype, fields)
Пример #4
0
    def sample_normalize(self, k_samples=1000, overwrite=False):
        """ Estimate the mean and std of the features from the training set
        Params:
            k_samples (int): Use this number of samples for estimation;
                negative means use the whole dataset
            overwrite (bool): passed through to preprocess_sample_normalize
        """
        log = LogUtil().getlogger()
        log.info("Calculating mean and std from samples")
        # if k_samples is negative then it goes through total dataset
        if k_samples < 0:
            audio_paths = self.audio_paths

        # using sample
        else:
            k_samples = min(k_samples, len(self.train_audio_paths))
            samples = self.rng.sample(self.train_audio_paths, k_samples)
            audio_paths = samples
        manager = Manager()
        return_dict = manager.dict()
        jobs = []
        num_processes = min(len(audio_paths), cpu_count())
        split_size = int(
            math.ceil(float(len(audio_paths)) / float(num_processes)))
        audio_paths_split = []
        for i in range(0, len(audio_paths), split_size):
            audio_paths_split.append(audio_paths[i:i + split_size])

        # fix: iterate over the chunks actually produced; indexing by
        # range(num_processes) raised IndexError whenever the number of
        # chunks came out smaller than num_processes
        for thread_index, chunk in enumerate(audio_paths_split):
            proc = Process(target=self.preprocess_sample_normalize,
                           args=(thread_index, chunk,
                                 overwrite, return_dict))
            jobs.append(proc)
            proc.start()
        for proc in jobs:
            proc.join()

        # aggregate per-worker partial sums
        feat = np.sum(np.vstack(
            [item['feat'] for item in return_dict.values()]),
                      axis=0)
        count = sum([item['count'] for item in return_dict.values()])
        # fix: removed stray debug print(feat, count)
        feat_squared = np.sum(np.vstack(
            [item['feat_squared'] for item in return_dict.values()]),
                              axis=0)

        # E[x] and sqrt(E[x^2] - E[x]^2)
        self.feats_mean = feat / float(count)
        self.feats_std = np.sqrt(feat_squared / float(count) -
                                 np.square(self.feats_mean))
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_mean'),
            self.feats_mean)
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_std'),
            self.feats_std)
        log.info("End calculating mean and std from samples")
Пример #5
0
    def load_metadata_from_desc_file(
        self,
        desc_file,
        partition='train',
        max_duration=16.0,
    ):
        """ Read metadata from the description file
            (possibly takes long, depending on the filesize)
        Params:
            desc_file (str):  Path to a JSON-line file that contains labels and
                paths to the audio files
            partition (str): One of 'train', 'validation' or 'test'
            max_duration (float): In seconds, the maximum duration of
                utterances to train or test on
        """
        logger = LogUtil().getlogger()
        logger.info('Reading description file: {} for partition: {}'.format(
            desc_file, partition))
        audio_paths, durations, texts = [], [], []
        with open(desc_file) as json_line_file:
            for line_num, json_line in enumerate(json_line_file):
                try:
                    spec = json.loads(json_line)
                    # skip utterances longer than the allowed maximum
                    if float(spec['duration']) > max_duration:
                        continue
                    audio_paths.append(spec['key'])
                    durations.append(float(spec['duration']))
                    texts.append(spec['text'])
                except Exception as e:
                    # Change to (KeyError, ValueError) or
                    # (KeyError,json.decoder.JSONDecodeError), depending on
                    # json module version
                    # fix: Logger.warn is a deprecated alias of warning
                    logger.warning('Error reading line #{}: {}'.format(
                        line_num, json_line))
                    logger.warning(str(e))

        if partition == 'train':
            self.count = len(audio_paths)
            self.train_audio_paths = audio_paths
            self.train_durations = durations
            self.train_texts = texts
        elif partition == 'validation':
            self.val_audio_paths = audio_paths
            self.val_durations = durations
            self.val_texts = texts
            self.val_count = len(audio_paths)
        elif partition == 'test':
            self.test_audio_paths = audio_paths
            self.test_durations = durations
            self.test_texts = texts
        else:
            raise Exception("Invalid partition to load metadata. "
                            "Must be train/validation/test")
Пример #6
0
def md_requestor(name, status, req_queue, config_file):
    """Worker loop: poll unreported rows from the sqlite req_resp table,
    pack each request, push it onto req_queue and mark the row reported.

    Params:
        name (str): worker name, used only in log messages
        status: shared object with a .value flag; loop runs while value == 0
        req_queue: queue receiving the packed message buffers
        config_file (str): path to the ini configuration file
    """
    # fix: this worker was logging under the copy-pasted name "db_reader"
    LogUtil.get_instance(config_file, "md_requestor")
    LogUtil.info("md_requestor:"+name+" begin")

    config = configparser.ConfigParser()
    config.read(config_file)

    factory = msg.MessageProcessorFactory()
    db_file = config.get("message_config", "db_file")
    factory.load_from_db(db_file, [])

    trade_db_file = config.get("reqresp", "db_file")
    # NOTE(review): getfloat() raises on a missing option rather than
    # returning None, so this fallback rarely fires; kept for safety.
    read_interval = config.getfloat("reqresp", "read_interval")
    if read_interval is None:
        read_interval = 0.5

    last_req_num = 0
    conn = sqlite3.connect(trade_db_file)
    query_curs = conn.cursor()
    update_curs = conn.cursor()
    query_unreported = """SELECT reqnum,message_type,appid,oms_order_id,order_status,req_text
        FROM req_resp
        WHERE order_status='0'
        AND reqnum>?
        ORDER BY reqnum
    """
    update_reported = """UPDATE req_resp
    SET order_status=?,report_time=strftime('%Y-%m-%d %H:%M:%f','now')
    WHERE reqnum in(?)
    """
    last_read_cnt = 1
    try:
        while status.value == 0:
            # Sleep only when the previous poll found nothing, so a busy
            # table is drained without artificial delay.
            if last_read_cnt == 0:
                time.sleep(read_interval)

            last_read_cnt = 0
            query_curs.execute(query_unreported, [last_req_num])
            for (reqnum, message_type, appid, message_id, order_status, req_text) in query_curs.fetchall():
                message_processor = factory.build_message_processor(message_type)
                send_buff = message_processor.pack(req_text)
                req_queue.put(send_buff)
                last_req_num = reqnum
                # '2' marks the row as reported
                update_curs.execute(update_reported, ['2', reqnum])
                last_read_cnt = last_read_cnt+1
                if send_buff:
                    LogUtil.debug("md_requestor putQ:"+binascii.hexlify(send_buff).decode())
                else:
                    LogUtil.debug("md_requestor putQ: send_buff NULL")
            conn.commit()
    finally:
        # fix: close the connection on shutdown/error instead of leaking it
        conn.close()

    LogUtil.info("md_requestor:"+name+" end")
Пример #7
0
    def update(self, labels, preds):
        """Accumulate CER (character error rate) statistics for one batch.

        Params:
            labels: batch labels (mx NDArrays), one per device
            preds: model outputs (mx NDArrays), one per device

        Updates self.total_n_label, self.total_l_dist, self.num_inst and
        self.sum_metric in place.
        """
        check_label_shapes(labels, preds)
        if self.is_logging:
            log = LogUtil().getlogger()
            labelUtil = LabelUtil()
        self.batch_loss = 0.
        # log.info(self.audio_paths)
        host_name = socket.gethostname()
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            # pred rows are interleaved per time step: row index is
            # k * samples_per_device + i (see the inner loop below)
            seq_length = len(pred) / int(
                int(self.batch_size) / int(self.num_gpu))

            # iterate over the samples handled by this device
            for i in range(int(int(self.batch_size) / int(self.num_gpu))):
                l = remove_blank(label[i])
                p = []
                probs = []
                for k in range(int(seq_length)):
                    p.append(
                        np.argmax(pred[
                            k * int(int(self.batch_size) / int(self.num_gpu)) +
                            i]))
                    probs.append(
                        pred[k * int(int(self.batch_size) / int(self.num_gpu))
                             + i])
                # collapse repeats/blanks into the final prediction
                p = pred_best(p)

                # NOTE(review): raises ZeroDivisionError when the label is
                # empty (len(l) == 0) — presumably labels are never empty here
                l_distance = levenshtein_distance(l, p)
                # l_distance = editdistance.eval(labelUtil.convert_num_to_word(l).split(" "), res)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                # only log clearly-bad utterances (cer above 0.4)
                if self.is_logging and this_cer > 0.4:
                    log.info("%s label: %s " %
                             (host_name, labelUtil.convert_num_to_word(l)))
                    log.info(
                        "%s pred : %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, labelUtil.convert_num_to_word(p),
                           this_cer, l_distance, len(l)))
                    # log.info("ctc_loss: %.2f" % ctc_loss(l, pred, i, int(seq_length), int(self.batch_size), int(self.num_gpu)))
                self.num_inst += 1
                self.sum_metric += this_cer
                # if self.is_epoch_end:
                #    loss = ctc_loss(l, pred, i, int(seq_length), int(self.batch_size), int(self.num_gpu))
                #    self.batch_loss += loss
                #    if self.is_logging:
                #        log.info("loss: %f " % loss)
        # ctc loss accumulation is currently disabled (see commented block)
        self.total_ctc_loss += 0  # self.batch_loss
Пример #8
0
    def load_metadata_from_desc_file(self, desc_file, partition='train',
                                     max_duration=16.0,):
        """ Read metadata from the description file
            (possibly takes long, depending on the filesize)
        Params:
            desc_file (str):  Path to a JSON-line file that contains labels and
                paths to the audio files
            partition (str): One of 'train', 'validation' or 'test'
            max_duration (float): In seconds, the maximum duration of
                utterances to train or test on
        """
        logger = LogUtil().getlogger()
        logger.info('Reading description file: {} for partition: {}'
                    .format(desc_file, partition))
        audio_paths, durations, texts = [], [], []
        with open(desc_file) as json_line_file:
            for line_num, json_line in enumerate(json_line_file):
                try:
                    spec = json.loads(json_line)
                    # skip utterances longer than the allowed maximum
                    if float(spec['duration']) > max_duration:
                        continue
                    audio_paths.append(spec['key'])
                    durations.append(float(spec['duration']))
                    texts.append(spec['text'])
                except Exception as e:
                    # Change to (KeyError, ValueError) or
                    # (KeyError,json.decoder.JSONDecodeError), depending on
                    # json module version
                    # fix: Logger.warn is a deprecated alias of warning
                    logger.warning('Error reading line #{}: {}'
                                   .format(line_num, json_line))
                    logger.warning(str(e))

        if partition == 'train':
            self.count = len(audio_paths)
            self.train_audio_paths = audio_paths
            self.train_durations = durations
            self.train_texts = texts
        elif partition == 'validation':
            self.val_audio_paths = audio_paths
            self.val_durations = durations
            self.val_texts = texts
            self.val_count = len(audio_paths)
        elif partition == 'test':
            self.test_audio_paths = audio_paths
            self.test_durations = durations
            self.test_texts = texts
        else:
            raise Exception("Invalid partition to load metadata. "
                            "Must be train/validation/test")
Пример #9
0
def tgw_recv(name, status, sock, resp_queue, config_file):
    """Receive loop: read raw bytes from sock and enqueue them untouched.

    Params:
        name (str): worker name, used only in log messages
        status: shared object with a .value flag; loop runs while value == 0
        sock: connected socket to read from
        resp_queue: queue receiving the raw byte chunks
        config_file (str): path to the ini configuration file
    """
    LogUtil.get_instance(config_file, "tgw_recv")
    LogUtil.info("tgw_recv:"+name+" begin")

    while status.value==0:
        # fix: removed the no-op try/finally: pass wrapper — it neither
        # caught nor cleaned up anything
        recv_data = sock.recv(1024)
        if not recv_data:
            # NOTE(review): empty recv means the peer closed the socket;
            # the loop keeps spinning and logging — consider breaking out.
            LogUtil.error('Recv message error!')
        else:
            LogUtil.debug('tgw recv:'+binascii.hexlify(recv_data).decode())
            #to make the recv faster, do NOT process more, just put the message to the queue
            resp_queue.put(recv_data)

    LogUtil.info("tgw_recv:"+name+" end")
Пример #10
0
def mdgw_recv(name, status, sock, resp_queue, config_file):
    """Receive loop: read raw bytes from sock and enqueue them with a
    sequence number and receive timestamp.

    Params:
        name (str): worker name, used only in log messages
        status: shared object with a .value flag; loop runs while value == 0
        sock: connected socket to read from
        resp_queue: queue receiving (buf_id, src_time, raw_bytes) tuples
        config_file (str): path to the ini configuration file
    """
    LogUtil.get_instance(config_file, "mdgw_recv")
    LogUtil.info("mdgw_recv:"+name+" begin")
    buf_id = 0
    while status.value == 0:
        # fix: removed the no-op try/finally: pass wrapper — it neither
        # caught nor cleaned up anything
        recv_data = sock.recv(1024)
        if not recv_data:
            # NOTE(review): empty recv means the peer closed the socket;
            # the loop keeps spinning and logging — consider breaking out.
            LogUtil.error('Recv message error!')
        else:
            buf_id = buf_id+1
            src_time = datetime.datetime.now()
            # to make the recv faster, do NOT process more, just put the message to the queue
            resp_queue.put((buf_id, src_time, recv_data))
            LogUtil.debug('mdgw recv:'+binascii.hexlify(recv_data).decode())
    LogUtil.info("mdgw_recv:"+name+" end")
Пример #11
0
def test_unpack(buff_str):
    """Smoke-test unpacking of a hex string holding concatenated messages.

    Params:
        buff_str (str): hex-encoded byte stream of packed messages
    """
    factory=msg.MessageProcessorFactory()
    factory.load_from_db('message_config.s3db', [])
    left_buff=binascii.a2b_hex(buff_str)
    # fix: a header needs header_len+footer_len bytes to be readable; the
    # original compared against header_len+header_len (consistent with the
    # identical check inside the loop below)
    if len(left_buff)<Message.header_len+Message.footer_len:
        LogUtil.info("next message not ready1")
    (message_type, body_len)=msg.get_message_header(left_buff)
    next_message_len=body_len+ Message.header_len+Message.footer_len
    while next_message_len<=len(left_buff):
        message_processor=factory.build_message_processor(message_type)
        message=message_processor.unpack(left_buff)
        LogUtil.debug("message:"+message.toString())
        # drop the consumed message and look at the next one
        left_buff=left_buff[next_message_len:]
        if len(left_buff)<Message.header_len+Message.footer_len:
            LogUtil.debug("break @left size:"+str(len(left_buff)))
            break
        else:
            (message_type, body_len)=msg.get_message_header(left_buff)
            next_message_len=body_len+ Message.header_len+Message.footer_len
            LogUtil.debug("MsgType="+str(message_type)+" body_len="+str(body_len))
    LogUtil.debug("left buff:"+binascii.hexlify(left_buff).decode())
Пример #12
0
    def update(self, labels, preds):
        """Accumulate CER statistics for one batch, logging every utterance.

        Params:
            labels: batch labels (mx NDArrays), one per device
            preds: model outputs (mx NDArrays), one per device

        Updates self.total_n_label, self.total_l_dist, self.num_inst,
        self.sum_metric and (at epoch end) self.total_ctc_loss in place.
        """
        check_label_shapes(labels, preds)

        log = LogUtil().getlogger()
        labelUtil = LabelUtil.getInstance()

        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            # iterate over the samples handled by this device; pred rows are
            # interleaved per time step (row = k * samples_per_device + i)
            for i in range(int(int(self.batch_size) / int(self.num_gpu))):

                l = remove_blank(label[i])
                p = []
                for k in range(int(self.seq_length)):
                    p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                # collapse repeats/blanks into the final prediction
                p = pred_best(p)

                l_distance = levenshtein_distance(l, p)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                log.info("label: %s " % (labelUtil.convert_num_to_word(l)))
                log.info("pred : %s , cer: %f (distance: %d/ label length: %d)" % (
                    labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
                self.num_inst += 1
                self.sum_metric += this_cer
            # NOTE(review): this sits OUTSIDE the inner sample loop, so the
            # loss is computed from only the last sample's l/i of each device
            # batch — sibling implementations place it inside the loop;
            # confirm whether this is intentional.
            if self.is_epoch_end:
                loss = ctc_loss(l, pred, i, int(self.seq_length), int(self.batch_size), int(self.num_gpu))
                self.total_ctc_loss += loss
                log.info("loss: %f " % loss)
Пример #13
0
    def update(self, labels, preds):
        """Accumulate CER statistics (and, at epoch end, ctc loss) for one batch.

        Params:
            labels: batch labels (mx NDArrays), one per device
            preds: model outputs (mx NDArrays), one per device

        Updates self.total_n_label, self.total_l_dist, self.num_inst,
        self.sum_metric, self.batch_loss and self.total_ctc_loss in place.
        """
        check_label_shapes(labels, preds)
        if self.is_logging:
            log = LogUtil().getlogger()
            labelUtil = LabelUtil.getInstance()
        self.batch_loss = 0.
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            # iterate over the samples handled by this device; pred rows are
            # interleaved per time step (row = k * samples_per_device + i)
            for i in range(int(int(self.batch_size) / int(self.num_gpu))):

                l = remove_blank(label[i])
                p = []
                for k in range(int(self.seq_length)):
                    p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                # collapse repeats/blanks into the final prediction
                p = pred_best(p)

                l_distance = levenshtein_distance(l, p)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                if self.is_logging:
                    log.info("label: %s " % (labelUtil.convert_num_to_word(l)))
                    log.info("pred : %s , cer: %f (distance: %d/ label length: %d)" % (
                        labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
                self.num_inst += 1
                self.sum_metric += this_cer
                # per-sample loss is only computed at epoch end (expensive)
                if self.is_epoch_end:
                    loss = ctc_loss(l, pred, i, int(self.seq_length), int(self.batch_size), int(self.num_gpu))
                    self.batch_loss += loss
                    if self.is_logging:
                        log.info("loss: %f " % loss)
        self.total_ctc_loss += self.batch_loss
Пример #14
0
    def sample_normalize(self, k_samples=1000, overwrite=False):
        """ Estimate the mean and std of the features from the training set
        Params:
            k_samples (int): Use this number of samples for estimation;
                negative means use the whole dataset
            overwrite (bool): passed through to preprocess_sample_normalize
        """
        log = LogUtil().getlogger()
        log.info("Calculating mean and std from samples")
        # if k_samples is negative then it goes through total dataset
        if k_samples < 0:
            audio_paths = self.audio_paths

        # using sample
        else:
            k_samples = min(k_samples, len(self.train_audio_paths))
            samples = self.rng.sample(self.train_audio_paths, k_samples)
            audio_paths = samples
        manager = Manager()
        return_dict = manager.dict()
        jobs = []
        # fix: every worker previously received the FULL audio_paths list, so
        # each file was processed cpu_count() times; split the work instead.
        # The resulting mean/std are unchanged: the uniform duplication
        # multiplied both the feature sums and the counts by the same factor.
        num_processes = max(1, min(cpu_count(), len(audio_paths)))
        split_size = max(1, -(-len(audio_paths) // num_processes))  # ceil div
        for worker_id, start in enumerate(range(0, len(audio_paths), split_size)):
            chunk = audio_paths[start:start + split_size]
            proc = Process(target=self.preprocess_sample_normalize,
                           args=(worker_id, chunk, overwrite, return_dict))
            jobs.append(proc)
            proc.start()
        for proc in jobs:
            proc.join()

        # aggregate per-worker partial sums
        feat = np.sum(np.vstack([item['feat'] for item in return_dict.values()]), axis=0)
        count = sum([item['count'] for item in return_dict.values()])
        feat_squared = np.sum(np.vstack([item['feat_squared'] for item in return_dict.values()]), axis=0)

        # E[x] and sqrt(E[x^2] - E[x]^2)
        self.feats_mean = feat / float(count)
        self.feats_std = np.sqrt(feat_squared / float(count) - np.square(self.feats_mean))
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_mean'), self.feats_mean)
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_std'), self.feats_std)
        log.info("End calculating mean and std from samples")
Пример #15
0
def do_training(args, module, data_train, data_val, begin_epoch=0, kv=None):
    """Main training loop: trains `module` on data_train, validates on
    data_val each epoch, logs CER/loss to tensorboard and saves checkpoints.

    Params:
        args: parsed arguments; args.config is the ConfigParser whose
            [common]/[train]/[optimizer]/[arch] options are read below
        module: an mx module, or a sym_gen callable when bucketing + load mode
        data_train: training data iterator
        data_val: validation data iterator
        begin_epoch (int): epoch to resume counting from
        kv: optional kvstore for (distributed) optimizer updates
    """
    from distutils.dir_util import mkpath

    host_name = socket.gethostname()
    log = LogUtil().getlogger()
    mkpath(os.path.dirname(config_util.get_checkpoint_path(args)))

    # seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    val_batch_size = args.config.getint('common', 'val_batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    # one context per device; num_gpu drives per-device batch splitting
    contexts = config_util.parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=val_batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_start = args.config.getfloat('train', 'learning_rate_start')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')
    lr_factor = args.config.getfloat('train', 'lr_factor')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    # kv = mx.kv.create(kvstore_option)
    # data = mx.io.ImageRecordIter(num_parts=kv.num_workers, part_index=kv.rank)
    # # a.set_optimizer(optimizer)
    # updater = mx.optimizer.get_updater(optimizer)
    # a._set_updater(updater=updater)

    # config value 0 means "no gradient clipping"
    if clip_gradient == 0:
        clip_gradient = None
    # resume a bucketing model from an existing checkpoint
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        # checkpoint files end in a 4-digit epoch number
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        prefix = args.config.get('common', 'prefix')
        if os.path.isabs(prefix):
            model_path = config_util.get_checkpoint_path(args).rsplit(
                "/", 1)[0] + "/" + str(model_name[:-5])
        # symbol, data_names, label_names = module(1600)
        model = mx.mod.BucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)

        # arg_params2 = {}
        # for item in arg_params.keys():
        #     if not item.startswith("forward") and not item.startswith("backward") and not item.startswith("rear"):
        #         arg_params2[item] = arg_params[item]
        # model.set_params(arg_params2, aux_params, allow_missing=True, allow_extra=True)

        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))
        # model_file = args.config.get('common', 'model_file')
        # model_name = os.path.splitext(model_file)[0]
        # model_num_epoch = int(model_name[-4:])
        # model_path = 'checkpoints/' + str(model_name[:-5])
        # _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        # arg_params2 = {}
        # for item in arg_params.keys():
        #     if not item.startswith("forward") and not item.startswith("backward") and not item.startswith("rear"):
        #         arg_params2[item] = arg_params[item]
        # module.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    # lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # (re)initialize the optimizer with the current scheduler/clipping config
    def reset_optimizer(force_init=False):
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kv,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    # tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)
    learning_rate_pre = 0
    # one iteration per epoch, until num_epoch is reached
    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info(host_name + '---------train---------')

        step_epochs = [
            int(l)
            for l in args.config.get('train', 'lr_step_epochs').split(',')
        ]
        # warm up to step_epochs[0] if step_epochs[0] > 0
        if n_epoch < step_epochs[0]:
            learning_rate_cur = learning_rate_start + n_epoch * (
                learning_rate - learning_rate_start) / step_epochs[0]
        else:
            # scaling lr every epoch
            if len(step_epochs) == 1:
                learning_rate_cur = learning_rate
                for s in range(n_epoch):
                    learning_rate_cur /= learning_rate_annealing
            # scaling lr by step_epochs[1:]
            else:
                learning_rate_cur = learning_rate
                for s in step_epochs[1:]:
                    if n_epoch > s:
                        learning_rate_cur *= lr_factor

        # momentum correction: scale lr by (cur/pre) when the rate changes
        if learning_rate_pre and args.config.getboolean(
                'train', 'momentum_correction'):
            lr_scheduler.learning_rate = learning_rate_cur * learning_rate_cur / learning_rate_pre
        else:
            lr_scheduler.learning_rate = learning_rate_cur
        learning_rate_pre = learning_rate_cur
        log.info("n_epoch %d's lr is %.7f" %
                 (n_epoch, lr_scheduler.learning_rate))
        summary_writer.add_scalar('lr', lr_scheduler.learning_rate, n_epoch)
        for nbatch, data_batch in enumerate(data_train):

            module.forward_backward(data_batch)
            module.update()
            # tensorboard setting
            if (nbatch + 1) % show_every == 0:
                # loss_metric.set_audio_paths(data_batch.index)
                module.update_metric(loss_metric, data_batch.label)
                # print("loss=========== %.2f" % loss_metric.get_batch_loss())
            # summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch,
                         nbatch)
                save_checkpoint(
                    module,
                    prefix=config_util.get_checkpoint_path(args) + "n_epoch" +
                    str(n_epoch) + "n_batch",
                    epoch=(int(
                        (nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info(host_name + '---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # when is_train = False it leads to high cer when batch_norm
            module.forward(data_batch, is_train=True)
            eval_metric.set_audio_paths(data_batch.index)
            module.update_metric(eval_metric, data_batch.label)

        # tensorboard setting
        val_cer, val_n_label, val_l_dist, val_ctc_loss = eval_metric.get_name_value(
        )
        log.info("Epoch[%d] val cer=%f (%d / %d), ctc_loss=%f", n_epoch,
                 val_cer, int(val_n_label - val_l_dist), val_n_label,
                 val_ctc_loss)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        summary_writer.add_scalar('loss validation', val_ctc_loss, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'
        np.random.seed(n_epoch)
        data_train.reset()
        data_train.is_first_epoch = False

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value(
        )
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            save_checkpoint(module,
                            prefix=config_util.get_checkpoint_path(args),
                            epoch=n_epoch,
                            save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
Пример #16
0
        res = ctc_greedy_decode(probs, self.labelUtil.byList)
        log.info("greedy decode cost %.3f, result is:\n%s" %
                 (time.time() - st, res))
        beam_size = 5
        from stt_metric import ctc_beam_decode
        st = time.time()
        results = ctc_beam_decode(scorer=self.scorer,
                                  beam_size=beam_size,
                                  vocab=self.labelUtil.byList,
                                  probs=probs)
        log.info("beam decode cost %.3f, result is:\n%s" %
                 (time.time() - st, "\n".join(results)))
        return "greedy:\n" + res + "\nbeam:\n" + "\n".join(results)


if __name__ == '__main__':
    # Entry point: load config, set up logging, build the network, then
    # serve decode requests over HTTP until the process is killed.
    if len(sys.argv) <= 1:
        raise Exception('cfg file path must be provided. ' +
                        'ex)python main.py --configfile examplecfg.cfg')
    args = parse_args(sys.argv[1])
    # set log file name
    log_filename = args.config.get('common', 'log_filename')
    log = LogUtil(filename=log_filename).getlogger()
    # NOTE(review): `otherNet` looks unused below; presumably the request
    # handler reads it as a module-level global — confirm against
    # SimpleHTTPRequestHandler's implementation.
    otherNet = Net(args)

    server = HTTPServer(('', args.config.getint('common', 'port')),
                        SimpleHTTPRequestHandler)
    log.info('Started httpserver on port')
    # Wait forever for incoming http requests
    server.serve_forever()
Пример #17
0
            model.set_params(arg_params, aux_params, allow_missing=True)
            model_loaded = model
        else:
            model_loaded.bind(for_training=False,
                              data_shapes=data_train.provide_data,
                              label_shapes=data_train.provide_label)
        max_t_count = args.config.getint('arch', 'max_t_count')

        eval_metric = EvalSTTMetric(batch_size=batch_size,
                                    num_gpu=num_gpu,
                                    scorer=get_scorer())
        if is_batchnorm:
            st = time.time()
            result = []
            for p in random_search():
                log.info("alpha %s, beta %s" % (p.get("alpha"), p.get("beta")))
                eval_metric = EvalSTTMetric(batch_size=batch_size,
                                            num_gpu=num_gpu,
                                            scorer=get_scorer(
                                                alpha=p.get("alpha"),
                                                beta=p.get("beta")))
                for nbatch, data_batch in enumerate(data_train):
                    st1 = time.time()
                    model_loaded.forward(data_batch, is_train=False)
                    log.info("forward spent is %.2fs" % (time.time() - st1))
                    model_loaded.update_metric(eval_metric, data_batch.label)
                val_cer, val_cer_beam, val_n_label, val_l_dist, val_l_dist_beam, val_ctc_loss = eval_metric.get_name_value(
                )
                log.info(
                    "val cer=%f (%d / %d), cer_beam=%f (%d/%d) ctc_loss=%f",
                    val_cer, int(val_n_label - val_l_dist),
Пример #18
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Train a speech-to-text module, validating once per epoch.

    All hyper-parameters come from ``args.config``.  Each epoch runs one
    pass over ``data_train`` (sampling ``loss_metric`` every ``show_every``
    batches and checkpointing every ``save_checkpoint_every_n_batch``
    batches), then one pass over ``data_val`` for CER evaluation.  CER and
    CTC loss are written to the log and to TensorBoard.

    Params:
        args: argument holder; ``args.config`` is a ConfigParser.
        module: MXNet-module-API object (bind/init_params/update/...).
        data_train, data_val: iterators exposing provide_data/provide_label.
        begin_epoch (int): epoch index to resume from.  Parameters are only
            freshly initialized when ``begin_epoch == 0`` and mode is 'train'.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    # one metric instance for validation (evaluated at epoch end) and one
    # sampled during training
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len, is_logging=enable_logging_validation_metric, is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len, is_logging=enable_logging_train_metric, is_epoch_end=False)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    n_epoch = begin_epoch

    # a configured value of 0 means "no gradient clipping"
    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        # only sgd takes a momentum parameter; adam manages its own moments
        if optimizer == "sgd":
            module.init_optimizer(kvstore='device',
                                  optimizer=optimizer,
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'momentum': momentum,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=force_init)
        elif optimizer == "adam":
            module.init_optimizer(kvstore='device',
                                  optimizer=optimizer,
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=force_init)
        else:
            raise Exception('Supported optimizers are sgd and adam. If you want to implement others define them in train.py')

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)

    # tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)
    while True:

        if n_epoch >= num_epoch:
            break

        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # sample the train metric every `show_every` batches to keep
            # metric updates cheap
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                                       epoch=(int((nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                                       save_optimizer_states=save_optimizer_states)

        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # is_train=True: with batch_norm, is_train=False led to high CER
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # tensorboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        # BUGFIX: the original assigned learning_rate / learning_rate_annealing
        # (a constant expression), so the rate was annealed exactly once no
        # matter how many epochs ran.  Divide the scheduler's *current* rate
        # so the annealing compounds each epoch.
        lr_scheduler.learning_rate /= learning_rate_annealing

    log.info('FINISH')
Пример #19
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Train a (possibly bucketing) speech-to-text module, validating once
    per epoch.

    All hyper-parameters come from ``args.config``.  When bucketing is
    enabled and mode is 'load', a fresh STTBucketingModule is rebuilt from
    the symbol generator and parameters are restored from the checkpoint
    encoded in the configured model file name.  Each epoch runs one training
    pass (sampling ``loss_metric`` every ``show_every`` batches,
    checkpointing every ``save_checkpoint_every_n_batch`` batches), then one
    validation pass; CER and CTC loss go to the log and to mxboard.

    Params:
        args: argument holder; ``args.config`` is a ConfigParser.
        module: module-API object, or a symbol generator when bucketing.
        data_train, data_val: iterators exposing provide_data/provide_label.
        begin_epoch (int): epoch index to resume from.  Parameters are only
            freshly initialized when ``begin_epoch == 0`` and mode is 'train'.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    # one metric instance for validation (evaluated at epoch end) and one
    # sampled during training
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_validation_metric, is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_train_metric, is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    # extra optimizer kwargs come as a JSON object in the config
    optimizer_params_dictionary = json.loads(args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    # a configured value of 0 means "no gradient clipping"
    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        # resume a bucketing model: the checkpoint epoch is encoded in the
        # last 4 characters of the model file's base name
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        # base parameters, extended by whatever the config's JSON provides
        optimizer_params = {'lr_scheduler': lr_scheduler,
                            'clip_gradient': clip_gradient,
                            'wd': weight_decay}
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    # mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # sample the train metric every `show_every` batches to keep
            # metric updates cheap
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                                       epoch=(int((nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                                       save_optimizer_states=save_optimizer_states)

        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # is_train=True: with batch_norm, is_train=False led to high CER
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        # BUGFIX: the original assigned learning_rate / learning_rate_annealing
        # (a constant expression), so the rate was annealed exactly once no
        # matter how many epochs ran.  Divide the scheduler's *current* rate
        # so the annealing compounds each epoch.
        lr_scheduler.learning_rate /= learning_rate_annealing

    log.info('FINISH')
Пример #20
0
    def sample_normalize(self,
                         k_samples=1000,
                         overwrite=False,
                         noise_percent=0.4):
        """ Estimate the mean and std of the features from the training set.

        Params:
            k_samples (int): Use this number of samples for estimation.
                A negative value uses the whole training set, repeated 10x
                (presumably to average over random noise augmentation —
                TODO confirm).
            overwrite (bool): forwarded to spectrogram_from_file; presumably
                forces recomputation of cached features — TODO confirm.
            noise_percent (float): forwarded to spectrogram_from_file
                (noise-augmentation fraction — TODO confirm semantics).

        Side effects:
            Sets self.feats_mean / self.feats_std and writes both to text
            files under self.save_dir via np.savetxt.
        """
        log = LogUtil().getlogger()
        log.info("Calculating mean and std from samples")
        # if k_samples is negative then it goes through total dataset
        if k_samples < 0:
            audio_paths = self.train_audio_paths * 10

        # using sample
        else:
            k_samples = min(k_samples, len(self.train_audio_paths))
            samples = self.rng.sample(self.train_audio_paths, k_samples)
            audio_paths = samples

        # Accumulate sum(x) and sum(x^2) over all frames of all files so the
        # statistics are computed in a single pass; files are featurized
        # concurrently in a thread pool.  (Dead multiprocessing/Pool variants
        # of this loop were removed from here.)
        return_dict = {}
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=cpu_count()) as executor:
            feat_dim = self.feat_dim
            feat = np.zeros((1, feat_dim))
            feat_squared = np.zeros((1, feat_dim))
            count = 0
            future_to_f = {
                executor.submit(spectrogram_from_file,
                                f,
                                overwrite=overwrite,
                                noise_percent=noise_percent): f
                for f in audio_paths
            }
            for future in concurrent.futures.as_completed(future_to_f):
                f = future_to_f[future]
                try:
                    next_feat = future.result()
                    next_feat_squared = np.square(next_feat)
                    feat_vertically_stacked = np.concatenate(
                        (feat, next_feat)).reshape(-1, feat_dim)
                    feat = np.sum(feat_vertically_stacked,
                                  axis=0,
                                  keepdims=True)
                    feat_squared_vertically_stacked = np.concatenate(
                        (feat_squared,
                         next_feat_squared)).reshape(-1, feat_dim)
                    feat_squared = np.sum(feat_squared_vertically_stacked,
                                          axis=0,
                                          keepdims=True)
                    count += float(next_feat.shape[0])
                except Exception as exc:
                    # a failed file is logged and skipped; it simply does not
                    # contribute to the statistics
                    log.info('%r generated an exception: %s' % (f, exc))
            return_dict[1] = {
                'feat': feat,
                'feat_squared': feat_squared,
                'count': count
            }

        # combine per-worker partial sums (kept for compatibility with the
        # earlier multi-process implementation; only key 1 exists here)
        feat = np.sum(np.vstack(
            [item['feat'] for item in return_dict.values()]),
                      axis=0)
        count = sum([item['count'] for item in return_dict.values()])
        feat_squared = np.sum(np.vstack(
            [item['feat_squared'] for item in return_dict.values()]),
                              axis=0)

        # mean = E[x]; std via var = E[x^2] - E[x]^2
        self.feats_mean = feat / float(count)
        self.feats_std = np.sqrt(feat_squared / float(count) -
                                 np.square(self.feats_mean))
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_mean'),
            self.feats_mean)
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_std'),
            self.feats_std)
        log.info("End calculating mean and std from samples")
Пример #21
0
def tgw_send(name, status, sock, req_queue, config_file):
    """Sender loop: drain `req_queue` and write each packed message to
    `sock`, enqueuing a heartbeat message whenever the queue has been idle
    for a full send interval.

    Params:
        name (str): instance name used in log lines.
        status: shared object; the loop runs while status.value == 0.
        sock: connected socket; messages are written with sendall().
        req_queue: queue of packed request buffers to send.
        config_file (str): path of the ini-style config file.
    """
    LogUtil.get_instance(config_file, "tgw_send")
    LogUtil.info("tgw_send:"+name+" begin")

    config = configparser.ConfigParser()
    config.read(config_file)

    # clamp the negotiated heartbeat interval to (0, 1800); `is None`
    # replaces the original `== None` (identity, not equality, for None)
    heartbeat_interval = config.getint("tgw", "heartbeat_interval")
    LogUtil.info("heartbeat_interval:"+str(heartbeat_interval))
    if heartbeat_interval is None or heartbeat_interval <= 0 or heartbeat_interval >= 1800:
        heartbeat_interval = 60
        LogUtil.info("heartbeat_interval changed to:"+str(heartbeat_interval))

    # our own send interval must be positive and no longer than the
    # negotiated heartbeat interval
    send_heartbeat_interval = config.getint("tgw_send", "send_heartbeat_interval")
    LogUtil.info("send_heartbeat_interval:"+str(send_heartbeat_interval))
    if send_heartbeat_interval <= 0 or send_heartbeat_interval >= heartbeat_interval:
        send_heartbeat_interval = heartbeat_interval
        LogUtil.info("send_heartbeat_interval changed to:"+str(send_heartbeat_interval))

    # block on the queue at most read_timeout seconds so the heartbeat check
    # runs at least twice per send interval
    read_timeout = config.getint("tgw_send", "req_queue_timeout")
    LogUtil.info("read_timeout:"+str(read_timeout))
    if read_timeout <= 0 or read_timeout > send_heartbeat_interval/2:
        read_timeout = send_heartbeat_interval/2
        LogUtil.info("read_timeout changed to:"+str(read_timeout))

    last_send_time = 0
    heartbeat = msg.packHeartbeatMessage()

    while status.value == 0:
        try:
            message = req_queue.get(block=True, timeout=read_timeout)
        except queue.Empty:
            LogUtil.debug("req_queue no data")
            # queue idle: enqueue a heartbeat so it is sent through the same
            # path as ordinary messages
            this_time = time.time()
            if this_time-last_send_time >= send_heartbeat_interval:
                req_queue.put(heartbeat)
                last_send_time = this_time
        else:
            sock.sendall(message)
            LogUtil.info("tgw_send:"+binascii.hexlify(message).decode())
    LogUtil.info("tgw_send:"+name+" end")
                input.write(insid)
                input.write("\n")
        input.close()
    except:
        log.info("设置host失败,请用管理员身份运行")


if __name__ == "__main__":
    from sys import argv, exit

    # Qt application object; must exist before any widgets are created
    a = QApplication(argv)

    # WebSocket server that transports the QWebChannel to the HTML client
    server = QWebSocketServer("QWebChannel Standalone Server",
                              QWebSocketServer.NonSecureMode)
    if not server.listen(QHostAddress.LocalHost, int(web_sock_port)):
        # logged message: "listening on port ... failed, client already open..."
        log.info("监听端口 " + web_sock_port + " 失败,客户端已经打开...")
        exit(0)

    # wraps raw websocket clients into QWebChannel transports
    clientWrapper = WebSocketClientWrapper(server)

    channel = QWebChannel()
    clientWrapper.clientConnected.connect(channel.connectTo)

    # expose `dialog` to the JavaScript side under the id "dialog"
    dialog = Dialog()
    channel.registerObject("dialog", dialog)
    # ======================== initialize the OMServer connection ========================
    omClient = OMClient("omClient")
    omClient.start()
    # url=QUrl.fromLocalFile("./index.html");
    # url.setQuery("webChannelBaseUrl="+server.serverUrl().toString())
    # QDesktopServices.openUrl(url);
Пример #23
0
def md_responsor(name, status, resp_queue, config_file):
    """Reassemble raw buffers from `resp_queue` into complete messages and
    republish them over ZeroMQ PUB sockets.

    Two optional publications, controlled by config:
      * pub_buf: the raw (buf_id, src_time, recv_buff) tuples as received;
      * pub_msg: each decoded message as
        (message_type, message_id, src_time, message_str).

    Params:
        name (str): instance name used in log lines.
        status: shared object; the loop runs while status.value == 0.
        resp_queue: queue of (buf_id, src_time, recv_buff) tuples.
        config_file (str): path of the ini-style config file.
    """
    LogUtil.get_instance(config_file, "md_responsor")
    LogUtil.info("md_responsor:"+name+" begin")

    config = configparser.ConfigParser()
    config.read(config_file)

    factory = msg.MessageProcessorFactory()
    db_file = config.get("message_config", "db_file")
    factory.load_from_db(db_file, [])

    read_timeout = config.getint("md_responsor", "resp_queue_timeout")
    LogUtil.info("read_timeout:"+str(read_timeout))
    if read_timeout <= 0 or read_timeout > 60:
        read_timeout = 60
        LogUtil.info("read_timeout changed to:"+str(read_timeout))

    # NOTE(review): config.get returns a string, so any non-empty value
    # (including "false"/"0") enables publishing; getboolean may have been
    # intended — left unchanged to preserve existing configs.
    pub_buf = config.get("md_responsor", "pub_buf")
    pub_msg = config.get("md_responsor", "pub_msg")
    pub_buf_addr = config.get("md_responsor", "pub_buf_addr")
    pub_msg_addr = config.get("md_responsor", "pub_msg_addr")

    LogUtil.debug("pub_buf:"+pub_buf+",pub_buf_addr:"+pub_buf_addr)
    LogUtil.debug("pub_msg:"+pub_msg+",pub_msg_addr:"+pub_msg_addr)

    if pub_buf:
        buf_ctx = zmq.Context()
        buf_sock = buf_ctx.socket(zmq.PUB)
        buf_sock.bind(pub_buf_addr)
    if pub_msg:
        msg_ctx = zmq.Context()
        msg_sock = msg_ctx.socket(zmq.PUB)
        msg_sock.bind(pub_msg_addr)

    left_buff = b''
    message_id = 0
    while status.value == 0:
        # TODO:refactor,try recv; and then process the buff
        # TODO:when processing buff, abstract the condition to next_message_ready()
        try:
            (buf_id, src_time, recv_buff) = resp_queue.get(block=True, timeout=read_timeout)
            if pub_buf:  # TODO:topic?
                buf_sock.send_pyobj((buf_id, src_time, recv_buff))
            left_buff = left_buff+recv_buff
            # NOTE(review): this threshold adds header_len twice, while the
            # loop below uses header_len+footer_len — confirm which is intended.
            if len(left_buff) < Message.header_len+Message.header_len:
                continue
            (message_type, body_len) = msg.get_message_header(left_buff)
            next_message_len = body_len + Message.header_len+Message.footer_len
            # drain every complete message currently held in the buffer
            while next_message_len <= len(left_buff):
                try:
                    message_processor = factory.build_message_processor(message_type)
                    message = message_processor.unpack(left_buff)
                    message_id = message_id+1
                    LogUtil.debug("message:"+message.toString())
                    if pub_msg:  # TODO:topic?
                        src_time = datetime.datetime.now()
                        msg_sock.send_pyobj((message_type, message_id, src_time, message.message_str))
                except KeyError:
                    LogUtil.error("unkown message type:"+str(message_type))
                except Exception as e:
                    LogUtil.error(e)
                    # BUGFIX: traceback.print_exc() returns None, so the
                    # original "other error:"+print_exc() concatenation raised
                    # TypeError inside this handler; format_exc() returns the
                    # traceback as a string.
                    LogUtil.error("other error:"+traceback.format_exc())
                left_buff = left_buff[next_message_len:]
                if len(left_buff) < Message.header_len+Message.footer_len:
                    break
                else:
                    (message_type, body_len) = msg.get_message_header(left_buff)
                    next_message_len = body_len + Message.header_len+Message.footer_len
        except queue.Empty:
            LogUtil.debug("resp_queue no data")
    LogUtil.info("md_responsor:"+name+" end")
Пример #24
0
    def run(self):
        """Worker-thread loop: subscribe to decoded market-data messages over
        ZeroMQ and optionally persist each one to the market_data_message
        table, emitting a status signal to the UI after every iteration.

        Runs until self.toStop becomes truthy.
        """
        # toggle the UI buttons: receiving is now active
        self.dialog.pushButtonMsgRecvStart.setEnabled(False)
        self.dialog.pushButtonMsgRecvStop.setEnabled(True)

        LogUtil.getLoggerInstance(self.config, "msg_recv")
        LogUtil.info("msg_recv:"+" begin")

        config=configparser.ConfigParser()
        config.read(self.config)

        # NOTE(review): config.get returns a string, so any non-empty value
        # (even "false") enables DB writes — confirm getboolean wasn't meant.
        write_msg=config.get("msg_recv", "write_msg")
        if write_msg:
            md_host=config.get("msg_recv", "md_host")
            md_database=config.get("msg_recv", "md_database")
            md_user=config.get("msg_recv", "md_user")
            md_password=config.get("msg_recv", "md_password")
            md_conn=pgdb.connect(database=md_database, host=md_host, user=md_user, password=md_password)
            md_insert_msg_cursor=md_conn.cursor()
            md_insert_msg_sql='''insert into market_data_message(data_date,insert_time,message_type,message_id,message_content,src_time)
            values(%(data_date)s,%(insert_time)s,%(message_type)s,%(message_id)s,%(message_content)s,%(src_time)s)
            '''
            md_insert_msg_dict={'data_date':0, 'message_type':0, 'message_content':''}

        # SUB socket for decoded messages published elsewhere in the pipeline
        msg_sub_addr=config.get("msg_recv", "msg_sub_addr")
        msg_sub_topic=config.get("msg_recv", "msg_sub_topic")
        ctx=zmq.Context()
        sock=ctx.socket(zmq.SUB)
        sock.connect(msg_sub_addr)
        sock.setsockopt_string(zmq.SUBSCRIBE, msg_sub_topic)

        # counters reported to the UI via msgStatusUpdated
        msgRecvCnt=0
        msgWriteCnt=0
        msgUpdateTime=None
        msgErrCnt=0

        recvMsgStatus=RecvMsgStatus()
        msgStatus='Running'

        while not self.toStop:
            try:
                (message_type, message_id, src_time, recv_msg)=sock.recv_pyobj()
                msgRecvCnt=msgRecvCnt+1
                msgUpdateTime=datetime.datetime.now().strftime('%H:%M:%S.%f')
                if write_msg:
                    md_insert_msg_dict['data_date']=0
                    md_insert_msg_dict['insert_time']=datetime.datetime.now()
                    md_insert_msg_dict['message_type']=message_type
                    md_insert_msg_dict['message_id']=message_id
                    md_insert_msg_dict['message_content']=recv_msg
                    md_insert_msg_dict['src_time']=src_time
                    md_insert_msg_cursor.execute(md_insert_msg_sql, md_insert_msg_dict)
                    #TODO: per-message commit is simple but slow — consider batching
                    md_conn.commit()
                    msgWriteCnt=msgWriteCnt+1
            except Exception as e:
                # count and log the failure, then keep receiving
                msgErrCnt=msgErrCnt+1
                LogUtil.error(e)
                LogUtil.error("MsgRecvErr:msgErrCnt="+str(msgErrCnt)+",msgRecvCnt="+str(msgRecvCnt)+",msgWriteCnt="+str(msgWriteCnt))
            finally:
                pass
            # push the latest counters to the UI
            recvMsgStatus.msgRecvCnt=str(msgRecvCnt)
            recvMsgStatus.msgWriteCnt=str(msgWriteCnt)
            recvMsgStatus.msgErrCnt=str(msgErrCnt)
            recvMsgStatus.msgUpdateTime=msgUpdateTime
            recvMsgStatus.msgStatus=msgStatus
            self.msgStatusUpdated.emit(recvMsgStatus)

            LogUtil.debug("MsgRecvErr:msgErrCnt="+str(msgErrCnt)+",msgRecvCnt="+str(msgRecvCnt)+",msgWriteCnt="+str(msgWriteCnt))

        # loop exited: flip the UI buttons back
        self.dialog.pushButtonMsgRecvStart.setEnabled(True)
        self.dialog.pushButtonMsgRecvStop.setEnabled(False)
Пример #25
0
    def run(self):
        """Worker-thread loop: subscribe to raw market-data buffers over
        ZeroMQ and optionally persist each one (hex-encoded) to the
        market_data_buff table, emitting a status signal to the UI after
        every iteration.

        Runs until self.toStop becomes truthy.
        """
        # toggle the UI buttons: receiving is now active
        self.dialog.pushButtonBufRecvStart.setEnabled(False)
        self.dialog.pushButtonBufRecvStop.setEnabled(True)

        LogUtil.getLoggerInstance(self.config, "buf_recv")
        LogUtil.info("buf_recv:"+" begin")

        config=configparser.ConfigParser()
        config.read(self.config)

        # NOTE(review): config.get returns a string, so any non-empty value
        # (even "false") enables DB writes — confirm getboolean wasn't meant.
        write_buff=config.get("buf_recv", "write_buff")
        if write_buff:
            md_host=config.get("buf_recv", "md_host")
            md_database=config.get("buf_recv", "md_database")
            md_user=config.get("buf_recv", "md_user")
            md_password=config.get("buf_recv", "md_password")
            md_conn=pgdb.connect(database=md_database, host=md_host, user=md_user, password=md_password)
            md_insert_buf_cursor=md_conn.cursor()
            md_insert_buf_sql='''insert into market_data_buff(data_date,buff_id,insert_time,buff,src_time)
            values(%(data_date)s,%(buf_id)s,%(insert_time)s,%(buff)s,%(src_time)s)
            '''
            md_insert_buf_dict={'data_date':0, 'buf_id':0, 'insert_time':None, 'buff':'', 'src_time':None}

        # SUB socket for raw buffers published elsewhere in the pipeline
        buf_sub_addr=config.get("buf_recv", "buf_sub_addr")
        buf_sub_topic=config.get("buf_recv", "buf_sub_topic")
        ctx=zmq.Context()
        sock=ctx.socket(zmq.SUB)
        sock.connect(buf_sub_addr)
        sock.setsockopt_string(zmq.SUBSCRIBE, buf_sub_topic)

        # counters reported to the UI via bufStatusUpdated
        bufRecvCnt=0
        bufWriteCnt=0
        bufUpdateTime=None
        bufErrCnt=0
        bufStatus='Running'
        recvBufStatus=RecvBufStatus()

        while not self.toStop:
            try:
                (buf_id, src_time, recv_buff)=sock.recv_pyobj()
                bufRecvCnt=bufRecvCnt+1
                bufUpdateTime=datetime.datetime.now().strftime('%H:%M:%S.%f')
                if write_buff:
                    md_insert_buf_dict['data_date']=0
                    md_insert_buf_dict['insert_time']=datetime.datetime.now()
                    md_insert_buf_dict['buf_id']=buf_id
                    # hex-encode the raw bytes so they fit a text column
                    md_insert_buf_dict['buff']=binascii.hexlify(recv_buff).decode()
                    md_insert_buf_dict['src_time']=src_time
                    md_insert_buf_cursor.execute(md_insert_buf_sql, md_insert_buf_dict)
                    #TODO: per-buffer commit is simple but slow — consider batching
                    md_conn.commit()
                    bufWriteCnt=bufWriteCnt+1
            except Exception as e:
                # count and log the failure, then keep receiving
                bufErrCnt=bufErrCnt+1
                LogUtil.error(e)
                LogUtil.error("BufRecvErr:bufErrCnt="+str(bufErrCnt)+",bufRecvCnt="+str(bufRecvCnt)+",bufWriteCnt="+str(bufWriteCnt))
            finally:
                pass
            # push the latest counters to the UI
            recvBufStatus.bufRecvCnt=str(bufRecvCnt)
            recvBufStatus.bufWriteCnt=str(bufWriteCnt)
            recvBufStatus.bufErrCnt=str(bufErrCnt)
            recvBufStatus.bufUpdateTime=bufUpdateTime
            recvBufStatus.bufStatus=bufStatus
            self.bufStatusUpdated.emit(recvBufStatus)

            LogUtil.debug("BufRecvErr:bufErrCnt="+str(bufErrCnt)+",bufRecvCnt="+str(bufRecvCnt)+",bufWriteCnt="+str(bufWriteCnt))

        # loop exited: flip the UI buttons back
        self.dialog.pushButtonBufRecvStart.setEnabled(True)
        self.dialog.pushButtonBufRecvStop.setEnabled(False)
Пример #26
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Run the training loop for `module` on `data_train`, scoring on `data_val`.

    Args:
        args: arguments object carrying a ConfigParser-style `args.config`.
        module: MXNet module to bind, initialize and optimize.
        data_train: training iterator; reset once per epoch.
        data_val: validation iterator.
        begin_epoch: epoch index to resume from; parameters are only
            freshly initialized when this is 0.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    # Make sure the checkpoint directory exists before training starts.
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    cfg = args.config
    # NOTE(review): ConfigParser.get returns a string here — STTMetric is
    # presumably tolerant of that; confirm whether seq_length must be an int.
    seq_len = cfg.get('arch', 'max_t_count')
    batch_size = cfg.getint('common', 'batch_size')
    checkpoint_period = cfg.getint('common', 'save_checkpoint_every_n_epoch')

    contexts = parse_contexts(args)
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=len(contexts),
                            seq_length=seq_len)

    lr_scheduler = SimpleLRScheduler(
        cfg.getfloat('train', 'learning_rate'),
        momentum=cfg.getfloat('train', 'momentum'),
        optimizer=cfg.get('train', 'optimizer'))

    epoch = begin_epoch
    total_epochs = cfg.getint('train', 'num_epoch')
    clip_gradient = cfg.getfloat('train', 'clip_gradient')
    weight_decay = cfg.getfloat('train', 'weight_decay')
    keep_optimizer_states = cfg.getboolean('train', 'save_optimizer_states')
    show_every = cfg.getint('train', 'show_every')

    # A configured value of 0 means "no gradient clipping".
    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    # Fresh runs get freshly initialized parameters; resumed runs keep theirs.
    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    module.init_optimizer(kvstore='device',
                          optimizer=cfg.get('train', 'optimizer'),
                          optimizer_params={'clip_gradient': clip_gradient,
                                            'wd': weight_decay},
                          force_init=True)

    while epoch < total_epochs:
        eval_metric.reset()
        log.info('---------train---------')
        for batch_index, batch in enumerate(data_train):
            # Propagate the effective sample count so the LR scheduler can
            # rescale updates for partially filled batches.
            if batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = batch.effective_sample_count

            module.forward_backward(batch)
            module.update()
            if (batch_index + 1) % show_every == 0:
                module.update_metric(eval_metric, batch.label)

        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        for batch in data_val:
            module.update_metric(eval_metric, batch.label)

        data_train.reset()

        if epoch % checkpoint_period == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=epoch,
                                   save_optimizer_states=keep_optimizer_states)

        epoch += 1

    log.info('FINISH')
Пример #27
0
def load_data(args):
    """Build data iterators for the mode selected in args.config.

    Returns:
        (train_iter, validation_iter, args) when mode is 'train' or 'load';
        (test_iter, args) when mode is 'predict'.

    Raises:
        Exception: if mode is not one of train/predict/load.
        Warning: (raised as an exception) for batch_size 1 with batch norm.
    """
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception('mode must be the one of the followings - train,predict,load')
    batch_size = args.config.getint('common', 'batch_size')

    # Feature-frame geometry read from the [data] section.
    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean('train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()
    if mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json, max_duration=max_duration)
        datagen.load_validation_data(val_json, max_duration=max_duration)
        # Regenerate the bi-grapheme dictionary when the resource file is
        # missing or regeneration is forced by config.
        if is_bi_graphemes:
            if not os.path.isfile("resources/unicodemap_en_baidu_bi_graphemes.csv") or overwrite_bi_graphemes_dictionary:
                load_labelutil(labelUtil=labelUtil, is_bi_graphemes=False, language=language)
                generate_bi_graphemes_dictionary(datagen.train_texts+datagen.val_texts)
        load_labelutil(labelUtil=labelUtil, is_bi_graphemes=is_bi_graphemes, language=language)
        # Network output dimension follows the label vocabulary size.
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                log.info("Generate mean and std from samples")
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                log.info("Read mean and std from meta files")
                datagen.get_meta_from_file(
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    elif mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json, max_duration=max_duration)
        # NOTE(review): the return value `labelutil` is never used, and
        # language is hard-coded to "en" unlike the train branch above —
        # confirm intended.
        labelutil = load_labelutil(labelUtil, is_bi_graphemes, language="en")
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    # NOTE(review): raising Warning aborts here rather than warning; if a
    # non-fatal warning was intended, warnings.warn would be the usual call.
    if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by its duration in ascending order to implement sortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = \
            datagen.get_max_label_length(partition="train", is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = \
            datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    # The architecture module named in config provides the RNN init states.
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    # SortaGrad: only the training run presents shortest utterances first.
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean('train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(partition="train",
                                    count=datagen.count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=sort_by_duration,
                                    is_bi_graphemes=is_bi_graphemes,
                                    buckets=buckets,
                                    save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    if mode == 'train' or mode == 'load':
        # `buckets` below is only defined when is_bucketing — safe because
        # the same flag gates this branch.
        if is_bucketing:
            validation_loaded = BucketSTTIter(partition="validation",
                                              count=datagen.val_count,
                                              datagen=datagen,
                                              batch_size=batch_size,
                                              num_label=max_label_length,
                                              init_states=init_states,
                                              seq_length=max_t_count,
                                              width=whcs.width,
                                              height=whcs.height,
                                              sort_by_duration=False,
                                              is_bi_graphemes=is_bi_graphemes,
                                              buckets=buckets,
                                              save_feature_as_csvfile=save_feature_as_csvfile)
        else:
            validation_loaded = STTIter(partition="validation",
                                        count=datagen.val_count,
                                        datagen=datagen,
                                        batch_size=batch_size,
                                        num_label=max_label_length,
                                        init_states=init_states,
                                        seq_length=max_t_count,
                                        width=whcs.width,
                                        height=whcs.height,
                                        sort_by_duration=False,
                                        is_bi_graphemes=is_bi_graphemes,
                                        save_feature_as_csvfile=save_feature_as_csvfile)
        return data_loaded, validation_loaded, args
    elif mode == 'predict':
        return data_loaded, args
Пример #28
0
    def sample_normalize_fbank(self,
                               k_samples=1000,
                               overwrite=False,
                               noise_percent=0.4):
        """Estimate per-dimension mean and std of fbank features.

        Features are computed (or re-read) for a sample of training audio
        files, accumulated as running sums, and the resulting mean/std are
        stored on self.feats_mean / self.feats_std and written to disk.

        Args:
            k_samples: number of audio files to sample; a negative value
                selects the whole-dataset path below.
            overwrite: forwarded to fbank_from_file (recompute cached feats).
            noise_percent: forwarded to fbank_from_file (noise augmentation).
        """
        log = LogUtil().getlogger()
        log.info("Calculating mean and std from samples")
        # if k_samples is negative then it goes through total dataset
        if k_samples < 0:
            # NOTE(review): this repeats the training set 10 times rather
            # than once as the comment above suggests — possibly to average
            # over random noise injection; confirm intended.
            audio_paths = self.train_audio_paths * 10

        # using sample
        else:
            k_samples = min(k_samples, len(self.train_audio_paths))
            samples = self.rng.sample(self.train_audio_paths, k_samples)
            audio_paths = samples
        # return_dict mirrors an earlier multi-process layout; only slot 1
        # is ever filled in this thread-pool version.
        return_dict = {}
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=cpu_count()) as executor:
            # assumes 41 filterbank channels x 3 feature blocks — TODO confirm
            feat_dim = 3 * 41
            feat = np.zeros((1, feat_dim))
            feat_squared = np.zeros((1, feat_dim))
            count = 0
            future_to_f = {
                executor.submit(fbank_from_file,
                                f,
                                overwrite=overwrite,
                                noise_percent=noise_percent): f
                for f in audio_paths
            }
            for future in concurrent.futures.as_completed(future_to_f):
                # for f, data in zip(audio_paths, executor.map(spectrogram_from_file, audio_paths, overwrite=overwrite, noise_percent=noise_percent)):
                f = future_to_f[future]
                try:
                    # Accumulate running sums of features and squared
                    # features; `count` tracks total frames for the
                    # mean/std computation below.
                    next_feat = future.result().swapaxes(0, 1).reshape(
                        -1, feat_dim)
                    next_feat_squared = np.square(next_feat)
                    feat_vertically_stacked = np.concatenate(
                        (feat, next_feat)).reshape(-1, feat_dim)
                    feat = np.sum(feat_vertically_stacked,
                                  axis=0,
                                  keepdims=True)
                    feat_squared_vertically_stacked = np.concatenate(
                        (feat_squared,
                         next_feat_squared)).reshape(-1, feat_dim)
                    feat_squared = np.sum(feat_squared_vertically_stacked,
                                          axis=0,
                                          keepdims=True)
                    count += float(next_feat.shape[0])
                except Exception as exc:
                    # A failed file is logged and skipped; the estimate
                    # simply uses fewer samples.
                    log.info('%r generated an exception: %s' % (f, exc))
            return_dict[1] = {
                'feat': feat,
                'feat_squared': feat_squared,
                'count': count
            }

        # Combine partial sums (a single entry here) into global totals.
        feat = np.sum(np.vstack(
            [item['feat'] for item in return_dict.values()]),
                      axis=0)
        count = sum([item['count'] for item in return_dict.values()])
        feat_squared = np.sum(np.vstack(
            [item['feat_squared'] for item in return_dict.values()]),
                              axis=0)

        # std = sqrt(E[x^2] - E[x]^2)
        self.feats_mean = feat / float(count)
        self.feats_std = np.sqrt(feat_squared / float(count) -
                                 np.square(self.feats_mean))
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_mean'),
            self.feats_mean)
        np.savetxt(
            generate_file_path(self.save_dir, self.model_name, 'feats_std'),
            self.feats_std)
        log.info("End calculating mean and std from samples")
Пример #29
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Bucketing-aware training loop with mxboard summaries.

    Binds (or, for bucketing 'load' mode, rebuilds and restores) the module,
    then trains for num_epoch epochs, logging CER/loss to mxboard and
    writing periodic per-epoch and per-batch checkpoints.

    Args:
        args: arguments object carrying a ConfigParser-style `args.config`.
        module: MXNet module, or a sym_gen callable in bucketing 'load' mode.
        data_train: training iterator; must support reset() and an
            `is_first_epoch` flag (SortaGrad).
        data_val: validation iterator.
        begin_epoch: epoch index to resume from; parameters are only
            initialized for a fresh 'train' run.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    # Make sure the checkpoint directory exists before training starts.
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    # Separate metrics: eval_metric scores validation once per epoch,
    # loss_metric tracks training loss per batch.
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    # A configured value of 0 disables gradient clipping.
    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        # Rebuild a bucketing module and restore parameters from the
        # checkpoint named by model_file; the epoch number is parsed from
        # the file-name suffix.
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        # NOTE(review): the unpacked symbol/data_names/label_names are not
        # used below — presumably called only to exercise sym_gen; confirm.
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        # From here on, `module` refers to the restored bucketing model.
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        # (Re)create the optimizer; extra options come from the config's
        # optimizer_params_dictionary JSON blob.
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            # Mid-epoch checkpoints use a per-epoch prefix and a batch-based
            # epoch counter so file names stay unique.
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch,
                         nbatch)
                module.save_checkpoint(
                    prefix=get_checkpoint_path(args) + "n_epoch" +
                    str(n_epoch) + "n_batch",
                    epoch=(int(
                        (nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)
        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # when is_train = False it leads to high cer when batch_norm
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                 int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value(
        )
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        # NOTE(review): this divides the ORIGINAL learning_rate each epoch,
        # so the rate is constant after the first epoch — confirm whether
        # cumulative annealing was intended.
        lr_scheduler.learning_rate = learning_rate / learning_rate_annealing

    log.info('FINISH')
    def __init__(self, count, datagen, batch_size, num_label, init_states, seq_length, width, height,
                 sort_by_duration=True,
                 is_bi_graphemes=False,
                 language="zh",
                 zh_type="zi",
                 partition="train",
                 buckets=None,
                 save_feature_as_csvfile=False,
                 num_parts=1,
                 part_index=0,
                 noise_percent=0.4,
                 fbank=False
                 ):
        """Bucketing data iterator over a DataGenerator partition.

        Groups utterances into duration buckets (auto-derived when none are
        given), shards each bucket across `num_parts` workers, and exposes
        MXNet-style provide_data / provide_label descriptions.

        Fix: `buckets` previously defaulted to a shared mutable list (`[]`)
        that `buckets.sort()` below mutated across every instance created
        without an explicit bucket list; the default is now None.
        A caller-supplied list is still sorted in place, as before.
        """
        super(BucketSTTIter, self).__init__()

        if buckets is None:
            buckets = []

        self.maxLabelLength = num_label
        # global param
        self.batch_size = batch_size
        self.count = count
        self.num_label = num_label
        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
        self.width = width
        self.height = height
        self.datagen = datagen
        self.label = None
        self.is_bi_graphemes = is_bi_graphemes
        self.language = language
        self.zh_type = zh_type
        self.num_parts = num_parts
        self.part_index = part_index
        self.noise_percent = noise_percent
        self.fbank = fbank

        # Select metadata (durations in seconds, file paths, transcripts)
        # for the requested partition.
        if partition == 'train':
            durations = datagen.train_durations
            audio_paths = datagen.train_audio_paths
            texts = datagen.train_texts
        elif partition == 'validation':
            durations = datagen.val_durations
            audio_paths = datagen.val_audio_paths
            texts = datagen.val_texts
        elif partition == 'test':
            durations = datagen.test_durations
            audio_paths = datagen.test_audio_paths
            texts = datagen.test_texts
        else:
            raise Exception("Invalid partition to load metadata. "
                            "Must be train/validation/test")
        log = LogUtil().getlogger()
        # SortaGrad: present shortest utterances first when requested.
        if sort_by_duration:
            durations, audio_paths, texts = datagen.sort_by_duration(durations,
                                                                     audio_paths,
                                                                     texts)
        self.trainDataList = list(zip(durations, audio_paths, texts))
        self.trainDataIter = iter(self.trainDataList)
        self.is_first_epoch = True

        # Durations scaled to 10ms frames drive bucket assignment.
        data_lengths = [int(d * 100) for d in durations]
        if len(buckets) == 0:
            # Auto-derive buckets: every length with at least batch_size samples.
            buckets = [i for i, j in enumerate(np.bincount(data_lengths))
                       if j >= batch_size]
        if len(buckets) == 0:
            raise Exception(
                'There is no valid buckets. It may occured by large batch_size for each buckets. max bincount:%d batch_size:%d' % (
                    max(np.bincount(data_lengths)), batch_size))
        buckets.sort()
        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(data_lengths):
            buck = bisect.bisect_left(buckets, sent)
            if buck == len(buckets):
                # Longer than the largest bucket: drop the utterance.
                ndiscard += 1
                continue
            self.data[buck].append(self.trainDataList[i])
        if ndiscard != 0:
            print("WARNING: discarded %d sentences longer than the largest bucket." % ndiscard)
        # Shard each bucket across num_parts workers, truncating so every
        # worker gets the same number of samples per bucket.
        for index_buck, buck in enumerate(self.data):
            self.data[index_buck] = [d for index_d, d in enumerate(
                self.data[index_buck][:len(self.data[index_buck]) // self.num_parts * self.num_parts]) if
                                     index_d % self.num_parts == self.part_index]
            log.info("partition: %s, num_works: %d, part_index: %d %d's data size is %d " %
                     (partition, self.num_parts, self.part_index, index_buck, len(self.data[index_buck])))
        self.buckets = buckets
        self.nddata = []
        self.ndlabel = []
        self.default_bucket_key = max(buckets)

        # (bucket index, start offset) pairs enumerating every full batch.
        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        if not self.fbank:
            self.provide_data = [('data', (self.batch_size, self.default_bucket_key, width * height))] + init_states
        else:
            self.provide_data = [('data', (self.batch_size, 3, self.default_bucket_key, 41))] + init_states
        self.provide_label = [('label', (self.batch_size, self.maxLabelLength))]
        self.save_feature_as_csvfile = save_feature_as_csvfile
Пример #31
0
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Train `module` on `data_train`, scoring each epoch on `data_val`.

    Binds the module, initializes parameters for fresh runs (begin_epoch 0),
    then iterates epochs, checkpointing every
    `save_checkpoint_every_n_epoch` epochs.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    logger = LogUtil().getlogger()
    # Ensure the checkpoint directory exists up front.
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    conf = args.config
    # NOTE(review): ConfigParser.get yields a string; STTMetric presumably
    # accepts that for seq_length — confirm.
    max_t_count = conf.get('arch', 'max_t_count')
    batch_size = conf.getint('common', 'batch_size')
    ckpt_every = conf.getint('common', 'save_checkpoint_every_n_epoch')

    device_contexts = parse_contexts(args)
    metric = STTMetric(batch_size=batch_size,
                       num_gpu=len(device_contexts),
                       seq_length=max_t_count)

    scheduler = SimpleLRScheduler(conf.getfloat('train', 'learning_rate'),
                                  momentum=conf.getfloat('train', 'momentum'),
                                  optimizer=conf.get('train', 'optimizer'))

    num_epoch = conf.getint('train', 'num_epoch')
    grad_clip = conf.getfloat('train', 'clip_gradient')
    if grad_clip == 0:
        grad_clip = None  # 0 means "clipping disabled"
    wd = conf.getfloat('train', 'weight_decay')
    save_states = conf.getboolean('train', 'save_optimizer_states')
    show_every = conf.getint('train', 'show_every')

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)
    # Only fresh runs get freshly initialized parameters.
    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    module.init_optimizer(kvstore='device',
                          optimizer=conf.get('train', 'optimizer'),
                          optimizer_params={'clip_gradient': grad_clip,
                                            'wd': wd},
                          force_init=True)

    for epoch in range(begin_epoch, num_epoch):
        metric.reset()
        logger.info('---------train---------')
        for i, batch in enumerate(data_train):
            # Let the LR scheduler account for partially filled batches.
            if batch.effective_sample_count is not None:
                scheduler.effective_sample_count = batch.effective_sample_count

            module.forward_backward(batch)
            module.update()
            if (i + 1) % show_every == 0:
                module.update_metric(metric, batch.label)

        # commented for Libri_sample data set to see only train cer
        logger.info('---------validation---------')
        for batch in data_val:
            module.update_metric(metric, batch.label)

        data_train.reset()
        if epoch % ckpt_every == 0:
            logger.info('Epoch[%d] SAVE CHECKPOINT', epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=epoch,
                                   save_optimizer_states=save_states)

    logger.info('FINISH')
Пример #32
0
    def update(self, labels, preds):
        """Accumulate CER statistics for a batch of (label, prediction) pairs.

        Decodes each sample both greedily and with a CTC beam search,
        accumulates edit distances and label counts on self.total_*, and
        optionally logs per-sample results when self.is_logging is set.

        Fix: `log` and `labelUtil` were previously created only when
        self.is_logging was true, yet both are used unconditionally below
        (decoding and its logging), raising NameError whenever logging was
        disabled. They are now always created.
        """
        check_label_shapes(labels, preds)
        log = LogUtil().getlogger()
        labelUtil = LabelUtil()
        self.batch_loss = 0.
        host_name = socket.gethostname()
        # Samples handled per device; loop-invariant, hoisted out.
        samples_per_gpu = int(int(self.batch_size) / int(self.num_gpu))
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()
            # Frames per utterance: the time axis is flattened into len(pred).
            seq_length = len(pred) / samples_per_gpu
            for i in range(samples_per_gpu):
                l = remove_blank(label[i])
                # Gather this sample's per-frame probability rows; frames of
                # one sample are strided samples_per_gpu apart.
                probs = []
                for k in range(int(seq_length)):
                    probs.append(pred[k * samples_per_gpu + i])
                probs = np.array(probs)
                st = time.time()
                beam_size = 20
                results = ctc_beam_decode(self.scorer, beam_size,
                                          labelUtil.byList, probs)
                log.info("decode by ctc_beam cost %.2f result: %s" %
                         (time.time() - st, "\n".join(results)))

                res_str1 = ctc_greedy_decode(probs, labelUtil.byList)
                log.info("decode by pred_best: %s" % res_str1)

                # Word-level edit distances against the reference transcript
                # for the greedy and beam hypotheses.
                l_distance = editdistance.eval(
                    labelUtil.convert_num_to_word(l).split(" "), res_str1)
                l_distance_beam_cpp = editdistance.eval(
                    labelUtil.convert_num_to_word(l).split(" "), results[0])
                self.total_n_label += len(l)
                self.total_l_dist_beam_cpp += l_distance_beam_cpp
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                if self.is_logging:
                    log.info("%s label: %s " %
                             (host_name, labelUtil.convert_num_to_word(l)))
                    log.info(
                        "%s pred : %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, res_str1, this_cer, l_distance, len(l)))
                    log.info(
                        "%s predc: %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, " ".join(
                            results[0]), float(l_distance_beam_cpp) / len(l),
                           l_distance_beam_cpp, len(l)))
                self.total_ctc_loss += self.batch_loss
                self.placeholder = res_str1 + "\n" + "\n".join(results)
Пример #33
0
def load_data(args):
    """Build data iterators for training/validation or prediction.

    Every setting is read from ``args.config`` (a ConfigParser). Depending on
    ``common.mode`` this loads the train/val or test JSON manifests, builds
    the bi-grapheme dictionary if requested, loads/derives feature mean and
    std for normalization, and constructs bucketing or plain STT iterators.

    Parameters:
        args: parsed-arguments object exposing a ``config`` ConfigParser;
            several 'arch' options are written back into it as a side effect
            (n_classes, max_t_count, max_label_length).

    Returns:
        (train_iter, validation_iter, args) when mode is 'train' or 'load';
        (test_iter, args) when mode is 'predict'.

    Raises:
        Exception: if mode is not one of 'train', 'predict', 'load'.
        Warning: raised (not merely warned) for batch_size==1 with batchnorm.
    """
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'mode must be the one of the followings - train,predict,load')
    batch_size = args.config.getint('common', 'batch_size')

    # Feature geometry (width/height/channel/stride) of the input spectrograms.
    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train',
                                                  'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean(
        'train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()
    if mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json, max_duration=max_duration)
        datagen.load_validation_data(val_json, max_duration=max_duration)
        if is_bi_graphemes:
            # Rebuild the bi-grapheme dictionary from all train+val texts when
            # the CSV map is missing or an overwrite was explicitly requested.
            if not os.path.isfile(
                    "resources/unicodemap_en_baidu_bi_graphemes.csv"
            ) or overwrite_bi_graphemes_dictionary:
                load_labelutil(labelUtil=labelUtil,
                               is_bi_graphemes=False,
                               language=language)
                generate_bi_graphemes_dictionary(datagen.train_texts +
                                                 datagen.val_texts)
        load_labelutil(labelUtil=labelUtil,
                       is_bi_graphemes=is_bi_graphemes,
                       language=language)
        # Class count is derived from the loaded label set and written back.
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                log.info("Generate mean and std from samples")
                normalize_target_k = args.config.getint(
                    'train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                log.info("Read mean and std from meta files")
                datagen.get_meta_from_file(
                    np.loadtxt(
                        generate_file_path(save_dir, model_name,
                                           'feats_mean')),
                    np.loadtxt(
                        generate_file_path(save_dir, model_name, 'feats_std')))
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(
                    generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(
                    generate_file_path(save_dir, model_name, 'feats_std')))

    elif mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json, max_duration=max_duration)
        # NOTE(review): this binding is never used below — load_labelutil is
        # called for its side effect on labelUtil elsewhere; confirm whether
        # it returns anything at all.
        labelutil = load_labelutil(labelUtil, is_bi_graphemes, language="en")
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    # NOTE(review): raising a Warning aborts execution here rather than
    # emitting a warning; warnings.warn may have been intended — confirm
    # caller expectations before changing.
    if batch_size == 1 and is_batchnorm and (mode == 'train'
                                             or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by its duration in ascending order to implement sortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = \
            datagen.get_max_label_length(partition="train", is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = \
            datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    # The architecture module (named in the config) supplies the RNN initial
    # states required by the iterators.
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean(
        'train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(
            partition="train",
            count=datagen.count,
            datagen=datagen,
            batch_size=batch_size,
            num_label=max_label_length,
            init_states=init_states,
            seq_length=max_t_count,
            width=whcs.width,
            height=whcs.height,
            sort_by_duration=sort_by_duration,
            is_bi_graphemes=is_bi_graphemes,
            buckets=buckets,
            save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    if mode == 'train' or mode == 'load':
        # Validation data is never sorted (sort_by_duration=False).
        if is_bucketing:
            validation_loaded = BucketSTTIter(
                partition="validation",
                count=datagen.val_count,
                datagen=datagen,
                batch_size=batch_size,
                num_label=max_label_length,
                init_states=init_states,
                seq_length=max_t_count,
                width=whcs.width,
                height=whcs.height,
                sort_by_duration=False,
                is_bi_graphemes=is_bi_graphemes,
                buckets=buckets,
                save_feature_as_csvfile=save_feature_as_csvfile)
        else:
            validation_loaded = STTIter(
                partition="validation",
                count=datagen.val_count,
                datagen=datagen,
                batch_size=batch_size,
                num_label=max_label_length,
                init_states=init_states,
                seq_length=max_t_count,
                width=whcs.width,
                height=whcs.height,
                sort_by_duration=False,
                is_bi_graphemes=is_bi_graphemes,
                save_feature_as_csvfile=save_feature_as_csvfile)
        return data_loaded, validation_loaded, args
    elif mode == 'predict':
        return data_loaded, args
Пример #34
0
class Net(object):
    """Speech-to-text inference wrapper around a bucketing MXNet model.

    The constructor reads the cfg file named on the command line, loads a
    pre-trained checkpoint into an STTBucketingModule, and sets up a decode
    scorer (swig Scorer if available, else raw kenlm). `getTrans` then
    decodes a wav file to text.
    """

    def __init__(self):
        # The cfg file path comes from sys.argv, not from a parameter.
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided. ' +
                            'ex)python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        # set parameters from cfg file
        # give random seed
        self.random_seed = self.args.config.getint('common', 'random_seed')
        self.mx_random_seed = self.args.config.getint('common',
                                                      'mx_random_seed')
        # random seed for shuffling data list
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))
        # set log file name
        self.log_filename = self.args.config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check the number of gpus is positive divisor of the batch size for data parallel
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(self.args.config)

        # Fixed maximum sequence length used as the bucketing module's
        # default bucket; also written back into the config.
        default_bucket_key = 1600
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        self.model = STTBucketingModule(sym_gen=self.model_loaded,
                                        default_bucket_key=default_bucket_key,
                                        context=self.contexts)

        # The architecture module (named in config) supplies the RNN initial
        # states needed to bind the module.
        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = self.args.config.getint('data', 'width')
        height = self.args.config.getint('data', 'height')
        self.model.bind(data_shapes=[
            ('data', (self.batch_size, default_bucket_key, width * height))
        ] + init_states,
                        label_shapes=[
                            ('label',
                             (self.batch_size,
                              self.args.config.getint('arch',
                                                      'max_label_length')))
                        ],
                        for_training=True)

        # Load checkpoint weights; extra/missing params are tolerated.
        _, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        self.model.set_params(self.arg_params,
                              self.aux_params,
                              allow_extra=True,
                              allow_missing=True)

        # Prefer the C++ swig Scorer (beam search with external LM); fall
        # back to plain kenlm scoring when the swig wrapper is unavailable.
        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            self.log.info("vacab_list len is %d" % len(vocab_list))
            # NOTE(review): 0.26 / 0.1 look like LM alpha/beta weights —
            # confirm against the Scorer API.
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.scorer = km.score
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=km.score)

    def getTrans(self, wav_file):
        """Decode *wav_file* and return the transcription text.

        Runs a forward pass over each batch produced by load_data and
        beam-decodes the output probabilities with the configured scorer.
        """
        # NOTE(review): this calls load_data(args, wav_file) with two
        # arguments — it expects a load_data variant that accepts a wav
        # file, not the single-argument one defined elsewhere; confirm.
        self.data_train, self.args = load_data(self.args, wav_file)

        # self.model.set_params(self.arg_params, self.aux_params)

        # backward_t358_l1_batchnorm_moving_var

        model_loaded = self.model
        max_t_count = self.args.config.getint('arch', 'max_t_count')

        for nbatch, data_batch in enumerate(self.data_train):
            st = time.time()
            model_loaded.forward(data_batch, is_train=False)
            probs = model_loaded.get_outputs()[0].asnumpy()
            self.log.info("forward cost %.2f" % (time.time() - st))
            beam_size = 5
            from stt_metric import ctc_beam_decode
            st = time.time()
            results = ctc_beam_decode(scorer=self.scorer,
                                      beam_size=beam_size,
                                      vocab=self.labelUtil.byList,
                                      probs=probs)
            self.log.info("decode cost %.2f, result is:\n%s" %
                          (time.time() - st, "\n".join(results)))
        # NOTE(review): if the data iterator yields no batches, `results`
        # is unbound here and this raises NameError — confirm inputs.
        return "\n".join(results)
    # NOTE(review): the lines below appear to be residue from a different
    # (download/GUI) script pasted at class scope during scraping: `log` is
    # not defined in this scope, so the statement would raise NameError at
    # class-definition time — confirm and remove.
    log.info("download finished");
    # filename = download.path();
    # pos = filename.rfind("/")
    # if not filename[:pos] == "":
    #     os.startfile(filename[:pos]);  # open the directory
# Script entry point: starts a local QWebSocket server exposing a QWebChannel
# so that a web page can talk to the Python-side `Dialog` object, then starts
# the OM server client.
if __name__ == "__main__":
    from sys import argv, exit

    a = QApplication(argv)

    # Plain (non-TLS) websocket server on localhost:12345.
    server = QWebSocketServer(
        "QWebChannel Standalone Server",
        QWebSocketServer.NonSecureMode
    )
    if not server.listen(QHostAddress.LocalHost, 12345):
        # (Chinese log text: "failed to listen on port 12345; a client is
        # already open...") — left untranslated: it is a runtime string.
        log.info("监听端口 12345 失败,客户端已经打开...")
        exit(1)

    # Bridge raw websocket clients into QWebChannel transports.
    clientWrapper = WebSocketClientWrapper(server)

    channel = QWebChannel()
    clientWrapper.clientConnected.connect(channel.connectTo)

    # Expose the dialog object to the web page under the id "dialog".
    dialog = Dialog()
    channel.registerObject("dialog", dialog)
    #======================== initialize OMServer connection =============================
    omClient=OMClient("omClient");
    omClient.start();
    # url=QUrl.fromLocalFile("./index.html");
    # url.setQuery("webChannelBaseUrl="+server.serverUrl().toString())
    # QDesktopServices.openUrl(url);
Пример #36
0
def main():
    """Connect to the trading gateway (TGW), log on, and run worker processes.

    Builds a binary logon message (header + logon body + checksum footer),
    sends it over TCP, and on a successful logon confirmation spawns four
    worker processes (db_reader, db_writer, tgw_send, tgw_recv) that are
    coordinated through a shared status flag and two queues. Typing 'q'
    at the prompt shuts everything down and sends a logout message.
    """
    if len(sys.argv)<2:
        print("Usage: tgw.py config_file")
        sys.exit(0)

    #read mdgw connection config
    config_file=sys.argv[1]
        
    run_status=multiprocessing.Value('i', 0)#0: running; 1: exit

    # Wire format: big-endian header (type, body length), fixed-width logon
    # body, and a 4-byte checksum footer.
    message_header_struct = struct.Struct('!II')
    logon_struct = struct.Struct('!20s20sI16s32s')
    message_footer_struct = struct.Struct('!I')
    
    send_buff = ctypes.create_string_buffer(message_header_struct.size+logon_struct.size+message_footer_struct.size)
    
    bodyLength = logon_struct.size
    # Message type 1 = logon (matched against the reply below).
    message_header =(1, bodyLength)
    message_header_struct.pack_into(send_buff, 0, *message_header)    
   
    
    LogUtil.get_instance(config_file, "log")
    LogUtil.info("Begin")
    
    config=configparser.ConfigParser()
    config.read(config_file)
    sender_comp=config.get("tgw","sender_comp")
    target_comp=config.get("tgw","target_comp")    
    password=config.get("tgw","password")
    app_ver_id=config.get("tgw","app_ver_id")    
    
    # Pad each field to the fixed width demanded by the logon struct.
    sender_comp = str.encode(sender_comp.ljust(20))
    target_comp = str.encode(target_comp.ljust(20))
    password = str.encode(password.ljust(16))
    app_ver_id = str.encode(app_ver_id.ljust(32))    
    
    # NOTE(review): the literal 30 fills the 'I' field of the logon body —
    # presumably a heartbeat interval; confirm against the protocol spec.
    logon_body = (sender_comp, target_comp, 30, password, app_ver_id)
    logon_struct.pack_into(send_buff, message_header_struct.size, *logon_body)
    check_sum = msg.calculate_check_sum(send_buff, message_header_struct.size+logon_struct.size)
    message_footer_struct.pack_into(send_buff, message_header_struct.size+logon_struct.size, check_sum)
    
    sock = socket.socket(socket.AF_INET,  socket.SOCK_STREAM)
    server_ip=config.get("tgw", "ip")
    server_port=config.getint("tgw", "port")

    #logger initialize
    
    server_address = (server_ip, server_port)
    sock.connect(server_address)
    sock.settimeout(5)
    # NOTE(review): setblocking(True) is equivalent to settimeout(None), so
    # this line cancels the 5-second timeout set just above — confirm which
    # behavior was intended.
    sock.setblocking(True)
    try:
        LogUtil.debug(binascii.hexlify(send_buff))
        sock.sendall(send_buff)
        recv_data = sock.recv(1024)
        if not recv_data:
            LogUtil.error('Recv error')
        else:
            LogUtil.info('Recv OK')
            LogUtil.info(binascii.hexlify(recv_data))
            unpack_recv_data = message_header_struct.unpack_from(recv_data)
            LogUtil.info(unpack_recv_data)
            #print(binascii.hexlify(recv_data))
            # Reply with message type 1 means logon was confirmed.
            if unpack_recv_data[0]==1:
                LogUtil.info('Receive Login Confirm!')
                
                #TODO:send report sync
                factory=msg.MessageProcessorFactory()
                db_file=config.get("message_config", "db_file")
                factory.load_from_db(db_file, [])
                # Message type 5: request report replay from index 1.
                message_processor=factory.build_message_processor(5)
                buff=message_processor.pack("ReportIndex=1")
                sock.sendall(buff)
                
                req_queue=multiprocessing.Queue()
                resp_queue=multiprocessing.Queue()
                
                # Workers share run_status; setting it non-zero stops them.
                dbreader_proc=multiprocessing.Process(target=db_reader, args=('DBReader', run_status, req_queue, config_file))
                dbwriter_proc=multiprocessing.Process(target=db_writer, args=('DBWriter', run_status, resp_queue, config_file))
                send_proc=multiprocessing.Process(target=tgw_send, args=('TGW sender', run_status, sock, req_queue, config_file))
                recv_proc=multiprocessing.Process(target=tgw_recv, args=('TGW receiver', run_status, sock, resp_queue, config_file))
                
                dbreader_proc.start()
                dbwriter_proc.start()
                send_proc.start()
                recv_proc.start()
                
                time.sleep(10)
                
                # Simple console loop: anything but 'q' keeps running.
                cmd=input("enter command:")
                while cmd!='q':
                    time.sleep(2)
                    cmd=input("enter command:")
                    
                LogUtil.warning("sending exit cmd")
                run_status.value=1    
                
                dbreader_proc.join()
                dbwriter_proc.join()
                send_proc.join()
                recv_proc.join()
             
                #send the logout message and process the response
                logout_message=msg.packLogoutMessage()
                sock.sendall(logout_message)
                recv_data = sock.recv(1024)
                if not recv_data:
                    LogUtil.error('Recv logout_message error!')
                else:
                    LogUtil.info('Recv logout_message OK')
                    LogUtil.debug(binascii.hexlify(recv_data))
            
    finally:
        sock.close()
        LogUtil.info ('End')
Пример #37
0
def db_writer(name, status, resp_queue, config_file):
    """Drain response messages from resp_queue and persist them to the DB.

    Runs until ``status.value`` becomes non-zero. Raw bytes pulled from the
    queue are appended to a reassembly buffer; each complete wire message
    (header + body + checksum footer) is unpacked via the message-processor
    factory and its execution-report fields are written into ``req_resp``.

    Parameters:
        name: human-readable worker name used in log lines.
        status: multiprocessing.Value('i') run flag; 0 = keep running.
        resp_queue: multiprocessing.Queue carrying raw response bytes.
        config_file: path to the ini-style configuration file.
    """
    LogUtil.get_instance(config_file, "db_writer")
    LogUtil.info("db_writer:"+name+" begin")
    
    config=configparser.ConfigParser()
    config.read(config_file)
    
    # The factory knows how to unpack each wire message type.
    factory=msg.MessageProcessorFactory()
    db_file=config.get("message_config", "db_file")
    factory.load_from_db(db_file, [])
    
    # Clamp the queue-read timeout to the sane range (0, 60] seconds.
    read_timeout=config.getint("db_writer", "resp_queue_timeout")
    LogUtil.info("read_timeout:"+str(read_timeout))
    if read_timeout<=0 or read_timeout>60:
        read_timeout=60
        LogUtil.info("read_timeout changed to:"+str(read_timeout))
    
    host=config.get("reqresp", "host")
    database=config.get("reqresp", "database")
    user=config.get("reqresp","user")
    password=config.get("reqresp", "password")
    
    conn=pgdb.connect(database=database, host=host, user=user, password=password)
    update_curs=conn.cursor()
    
    update_resp="""UPDATE req_resp 
    SET rept_status=%(rept_status)s,ex_order_status=%(ex_order_status)s, 
    err_code=%(err_code)s, resp_text=%(resp_text)s, resp_time=localtimestamp 
    WHERE oms_order_id=%(oms_order_id)s    
    """ 
    update_dict={'rept_status':'', 'ex_order_status':'', 'err_code':'', 'resp_text':'', 'oms_order_id':''}
    left_buff=b''
    while status.value==0:
        #TODO:refactor,try recv; and then process the buff
        #TODO:when processing buff, abstract the condition to next_message_ready()
        # Track the last fully-parsed message of this iteration so the
        # try/else logging below never references an unbound name.
        message=None
        try:
            recv_buff=resp_queue.get(block=True, timeout=read_timeout)
            left_buff=left_buff+recv_buff
            # BUGFIX: was header_len+header_len; the minimum parsable frame
            # is a header plus a footer (matches the identical check below).
            if len(left_buff)<Message.header_len+Message.footer_len:
                continue
            (message_type, body_len)=msg.get_message_header(left_buff)
            next_message_len=body_len+ Message.header_len+Message.footer_len
            while next_message_len<=len(left_buff):
                try:
                    message_processor=factory.build_message_processor(message_type)
                    message=message_processor.unpack(left_buff)
                    LogUtil.debug("message:"+message.toString())
                    if True:#TODO placeholder, check if was order execution report
                        update_dict['rept_status']='4'
                        update_dict['ex_order_status']=message.order_status
                        update_dict['err_code']=message.order_reject_reason
                        update_dict['resp_text']=message.message_str
                        update_dict['oms_order_id']=message.client_order_id
                        
                        update_curs.execute(update_resp, update_dict)
                        if update_curs.rowcount!=1:
                            #TODO error handle, rollback?
                            LogUtil.error("no data update"+message.toString())
                        else:
                            conn.commit()                    
                except KeyError:
                    LogUtil.error("unkown message type:"+str(message_type))
                # Drop the consumed frame and look for the next one.
                left_buff=left_buff[next_message_len:]
                if len(left_buff)<Message.header_len+Message.footer_len:
                    break
                else:
                    (message_type, body_len)=msg.get_message_header(left_buff)
                    next_message_len=body_len+ Message.header_len+Message.footer_len
        except queue.Empty:
            LogUtil.debug("resp_queue no data")
        else:    
            # BUGFIX: only log when a complete message was actually parsed;
            # previously `message` could be unbound here, raising NameError
            # and killing the worker.
            if message is not None:
                LogUtil.info("db_writer finished processing:"+message.toString())
    LogUtil.info("db_writer:"+name+" end")
Пример #38
0
class Net(object):
    """Speech-to-text decoder using a manually-bound MXNet executor.

    Unlike the bucketing-module variant, this class splices a softmax onto
    an internal layer of the loaded checkpoint symbol, binds a CPU executor
    directly, and feeds recurrent state back in by hand between forward
    calls. `getTrans` decodes a single wav file with greedy CTC decoding.
    """

    def __init__(self):
        # The cfg file path comes from sys.argv, not from a parameter.
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided. ' +
                            'ex)python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        # set parameters from cfg file
        # give random seed
        self.random_seed = self.args.config.getint('common', 'random_seed')
        self.mx_random_seed = self.args.config.getint('common',
                                                      'mx_random_seed')
        # random seed for shuffling data list
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))
        # set log file name
        self.log_filename = self.args.config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

        # Feature mean/std are loaded from checkpoint meta files so inputs
        # can be normalized exactly as during training.
        save_dir = 'checkpoints'
        model_name = self.args.config.get('common', 'prefix')
        max_freq = self.args.config.getint('data', 'max_freq')
        self.datagen = DataGenerator(save_dir=save_dir,
                                     model_name=model_name,
                                     max_freq=max_freq)
        self.datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

        self.buckets = json.loads(self.args.config.get('arch', 'buckets'))

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check the number of gpus is positive divisor of the batch size for data parallel
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(self.args.config)

        # Hard-coded maximum sequence length; also written into the config.
        default_bucket_key = 1600
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(95))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        # self.model = STTBucketingModule(
        #     sym_gen=self.model_loaded,
        #     default_bucket_key=default_bucket_key,
        #     context=self.contexts
        # )

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = self.args.config.getint('data', 'width')
        height = self.args.config.getint('data', 'height')
        # Export a symbol JSON per bucket; the last loop iteration's
        # init_state_names is reused below.
        for bucket in self.buckets:
            net, init_state_names, ll = self.model_loaded(bucket)
            net.save('checkpoints/%s-symbol.json' % bucket)
        # NOTE(review): the label shape (1, 18) is hard-coded here while the
        # config says max_label_length=95 — confirm which is correct.
        input_shapes = dict([('data',
                              (self.batch_size, default_bucket_key,
                               width * height))] + init_states + [('label',
                                                                   (1, 18))])
        # self.executor = net.simple_bind(ctx=mx.cpu(), **input_shapes)

        # self.model.bind(data_shapes=[('data', (self.batch_size, default_bucket_key, width * height))] + init_states,
        #                 label_shapes=[
        #                     ('label', (self.batch_size, self.args.config.getint('arch', 'max_label_length')))],
        #                 for_training=True)

        # Splice a softmax output onto an internal concat layer of the
        # checkpoint symbol and bind a CPU executor for it.
        symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        all_layers = symbol.get_internals()
        # NOTE(review): layer name 'concat36457_output' is checkpoint-
        # specific — confirm it matches the loaded model.
        concat = all_layers['concat36457_output']
        sm = mx.sym.SoftmaxOutput(data=concat, name='softmax')
        self.executor = sm.simple_bind(ctx=mx.cpu(), **input_shapes)
        # self.model.set_params(self.arg_params, self.aux_params, allow_extra=True, allow_missing=True)

        # Copy only the checkpoint params that the executor actually uses.
        for key in self.executor.arg_dict.keys():
            if key in self.arg_params:
                self.arg_params[key].copyto(self.executor.arg_dict[key])
        init_state_names.remove('data')
        init_state_names.sort()
        # Map each recurrent state name to the executor output that carries
        # its next value (outputs[0] is the softmax probability).
        self.states_dict = dict(
            zip(init_state_names, self.executor.outputs[1:]))
        self.input_arr = mx.nd.zeros(
            (self.batch_size, default_bucket_key, width * height))

        # Prefer the C++ swig Scorer for the eval metric; fall back to raw
        # kenlm scoring when the swig wrapper is unavailable.
        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            self.log.info("vacab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                             num_gpu=self.num_gpu,
                                             is_logging=True,
                                             scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                             num_gpu=self.num_gpu,
                                             is_logging=True,
                                             scorer=km.score)

    def forward(self, input_data, new_seq=False):
        """Run one forward pass, feeding recurrent state back in by hand.

        Parameters:
            input_data: mx.nd.array of the batched, normalized features.
            new_seq: when True, zero all recurrent state first (start of a
                new utterance).

        Returns:
            numpy array of output probabilities (executor output 0).
        """
        if new_seq == True:
            for key in self.states_dict.keys():
                self.executor.arg_dict[key][:] = 0.
        input_data.copyto(self.executor.arg_dict["data"])
        self.executor.forward()
        # Copy each state output back into its matching input slot for the
        # next call.
        for key in self.states_dict.keys():
            self.states_dict[key].copyto(self.executor.arg_dict[key])
        prob = self.executor.outputs[0].asnumpy()
        return prob

    def getTrans(self, wav_file):
        """Decode *wav_file* and return the eval metric's placeholder text."""
        res = spectrogram_from_file(wav_file, noise_percent=0)
        # NOTE(review): `buck` is computed but unused — the bucket key is
        # hard-coded to 1600 below; confirm whether bucketing was intended.
        buck = bisect.bisect_left(self.buckets, len(res))
        bucket_key = 1600
        res = self.datagen.normalize(res)
        # Zero-pad the utterance into a fixed-length batch tensor.
        d = np.zeros((self.batch_size, bucket_key, res.shape[1]))
        d[0, :res.shape[0], :] = res
        st = time.time()
        # model_loaded.forward(data_batch, is_train=False)
        probs = self.forward(mx.nd.array(d))
        from stt_metric import ctc_greedy_decode
        res = ctc_greedy_decode(probs, self.labelUtil.byList)
        self.log.info("forward cost %.2f, %s" % (time.time() - st, res))
        st = time.time()
        # model_loaded.update_metric(self.eval_metric, data_batch.label)
        self.log.info("upate metric cost %.2f" % (time.time() - st))
        # print("my res is:")
        # print(eval_metric.placeholder)
        return self.eval_metric.placeholder