Example #1
def remove_space(l):
    labelUtil = LabelUtil()
    ret = []
    for i in range(len(l)):
        if l[i] != labelUtil.get_space_index():
            ret.append(l[i])
    return ret
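remove_space() drops every space token from an integer-encoded label sequence before the error-rate computation. A minimal self-contained sketch of the same filtering, where SPACE_INDEX is only a stand-in for labelUtil.get_space_index():

# Illustrative sketch; SPACE_INDEX stands in for labelUtil.get_space_index().
SPACE_INDEX = 0

def remove_space_sketch(labels):
    # Keep every token that is not the space index.
    return [token for token in labels if token != SPACE_INDEX]

print(remove_space_sketch([7, 0, 3, 0, 5]))  # -> [7, 3, 5]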
Example #2
    def update(self, labels, preds):
        check_label_shapes(labels, preds)
        if self.is_logging:
            log = LogUtil().getlogger()
            labelUtil = LabelUtil()
        self.batch_loss = 0.
        # log.info(self.audio_paths)
        host_name = socket.gethostname()
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            seq_length = len(pred) / int(
                int(self.batch_size) / int(self.num_gpu))

            for i in range(int(int(self.batch_size) / int(self.num_gpu))):
                l = remove_blank(label[i])
                p = []
                probs = []
                for k in range(int(seq_length)):
                    p.append(
                        np.argmax(pred[
                            k * int(int(self.batch_size) / int(self.num_gpu)) +
                            i]))
                    probs.append(
                        pred[k * int(int(self.batch_size) / int(self.num_gpu))
                             + i])
                p = pred_best(p)

                l_distance = levenshtein_distance(l, p)
                # l_distance = editdistance.eval(labelUtil.convert_num_to_word(l).split(" "), res)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                if self.is_logging and this_cer > 0.4:
                    log.info("%s label: %s " %
                             (host_name, labelUtil.convert_num_to_word(l)))
                    log.info(
                        "%s pred : %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, labelUtil.convert_num_to_word(p),
                           this_cer, l_distance, len(l)))
                    # log.info("ctc_loss: %.2f" % ctc_loss(l, pred, i, int(seq_length), int(self.batch_size), int(self.num_gpu)))
                self.num_inst += 1
                self.sum_metric += this_cer
                # if self.is_epoch_end:
                #    loss = ctc_loss(l, pred, i, int(seq_length), int(self.batch_size), int(self.num_gpu))
                #    self.batch_loss += loss
                #    if self.is_logging:
                #        log.info("loss: %f " % loss)
        self.total_ctc_loss += 0  # self.batch_loss
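update() accumulates a character error rate (CER): for each utterance it divides the Levenshtein distance between the blank-stripped label and the greedy prediction by the label length. A self-contained sketch of that computation; levenshtein_sketch is an illustrative re-implementation, not the project's levenshtein_distance:

def levenshtein_sketch(a, b):
    # Edit distance between two sequences (insertion/deletion/substitution cost 1).
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        cur = [i]
        for j, y in enumerate(b, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (x != y)))  # substitution
        prev = cur
    return prev[-1]

label = [3, 5, 7, 2]   # made-up token ids
pred = [3, 7, 2, 2]
cer = levenshtein_sketch(label, pred) / float(len(label))
print(cer)  # 0.5 -> two edits against a 4-token label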
Example #3
    def update(self, labels, preds):
        check_label_shapes(labels, preds)
        if self.is_logging:
            log = LogUtil().getlogger()
            labelUtil = LabelUtil.getInstance()
        self.batch_loss = 0.
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            for i in range(int(int(self.batch_size) / int(self.num_gpu))):

                l = remove_blank(label[i])
                p = []
                for k in range(int(self.seq_length)):
                    p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                p = pred_best(p)

                l_distance = levenshtein_distance(l, p)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                if self.is_logging:
                    log.info("label: %s " % (labelUtil.convert_num_to_word(l)))
                    log.info("pred : %s , cer: %f (distance: %d/ label length: %d)" % (
                        labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
                self.num_inst += 1
                self.sum_metric += this_cer
                if self.is_epoch_end:
                    loss = ctc_loss(l, pred, i, int(self.seq_length), int(self.batch_size), int(self.num_gpu))
                    self.batch_loss += loss
                    if self.is_logging:
                        log.info("loss: %f " % loss)
        self.total_ctc_loss += self.batch_loss
Example #4
def remove_space(l):
    labelUtil = LabelUtil.getInstance()
    ret = []
    for i in range(len(l)):
        if l[i] != labelUtil.get_space_index():
            ret.append(l[i])
    return ret
Example #5
    def update(self, labels, preds):
        check_label_shapes(labels, preds)

        log = LogUtil().getlogger()
        labelUtil = LabelUtil.getInstance()

        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()

            for i in range(int(int(self.batch_size) / int(self.num_gpu))):

                l = remove_blank(label[i])
                p = []
                for k in range(int(self.seq_length)):
                    p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                p = pred_best(p)

                l_distance = levenshtein_distance(l, p)
                self.total_n_label += len(l)
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                log.info("label: %s " % (labelUtil.convert_num_to_word(l)))
                log.info("pred : %s , cer: %f (distance: %d/ label length: %d)" % (
                    labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
                self.num_inst += 1
                self.sum_metric += this_cer
            if self.is_epoch_end:
                loss = ctc_loss(l, pred, i, int(self.seq_length), int(self.batch_size), int(self.num_gpu))
                self.total_ctc_loss += loss
                log.info("loss: %f " % loss)
Example #6
 def prepare_minibatch(self,
                       audio_paths,
                       texts,
                       overwrite=False,
                       is_bi_graphemes=False,
                       seq_length=-1,
                       save_feature_as_csvfile=False):
     """ Featurize a minibatch of audio, zero pad them and return a dictionary
     Params:
         audio_paths (list(str)): List of paths to audio files
         texts (list(str)): List of texts corresponding to the audio files
     Returns:
         dict: See below for contents
     """
     assert len(audio_paths) == len(
         texts
     ), "Inputs and outputs to the network must be of the same number"
     # Features is a list of (timesteps, feature_dim) arrays
     # Calculate the features for each audio clip, as the log of the
     # Fourier Transform of the audio
     features = [
         self.featurize(a,
                        overwrite=overwrite,
                        save_feature_as_csvfile=save_feature_as_csvfile)
         for a in audio_paths
     ]
     input_lengths = [f.shape[0] for f in features]
     feature_dim = features[0].shape[1]
     mb_size = len(features)
     # Pad all the inputs so that they are all the same length
     if seq_length == -1:
         x = np.zeros((mb_size, self.max_seq_length, feature_dim))
     else:
         x = np.zeros((mb_size, seq_length, feature_dim))
     y = np.zeros((mb_size, self.max_label_length))
     labelUtil = LabelUtil.getInstance()
     label_lengths = []
     for i in range(mb_size):
         feat = features[i]
         feat = self.normalize(feat)  # Center using means and std
         x[i, :feat.shape[0], :] = feat
         if is_bi_graphemes:
             label = generate_bi_graphemes_label(texts[i])
             label = labelUtil.convert_bi_graphemes_to_num(label)
             y[i, :len(label)] = label
         else:
             label = labelUtil.convert_word_to_num(texts[i])
             y[i, :len(texts[i])] = label
         label_lengths.append(len(label))
     return {
          'x': x,  # 0-padded features of shape (mb_size, timesteps, feat_dim)
         'y': y,  # list(int) Flattened labels (integer sequences)
         'texts': texts,  # list(str) Original texts
         'input_lengths': input_lengths,  # list(int) Length of each input
         'label_lengths': label_lengths,  # list(int) Length of each label
     }
Example #7
 def prepare_minibatch(self, audio_paths, texts, overwrite=False,
                       is_bi_graphemes=False, seq_length=-1, save_feature_as_csvfile=False):
     """ Featurize a minibatch of audio, zero pad them and return a dictionary
     Params:
         audio_paths (list(str)): List of paths to audio files
         texts (list(str)): List of texts corresponding to the audio files
     Returns:
         dict: See below for contents
     """
     assert len(audio_paths) == len(texts),\
         "Inputs and outputs to the network must be of the same number"
     # Features is a list of (timesteps, feature_dim) arrays
     # Calculate the features for each audio clip, as the log of the
     # Fourier Transform of the audio
     features = [self.featurize(a, overwrite=overwrite, save_feature_as_csvfile=save_feature_as_csvfile) for a in audio_paths]
     input_lengths = [f.shape[0] for f in features]
     feature_dim = features[0].shape[1]
     mb_size = len(features)
     # Pad all the inputs so that they are all the same length
     if seq_length == -1:
         x = np.zeros((mb_size, self.max_seq_length, feature_dim))
     else:
         x = np.zeros((mb_size, seq_length, feature_dim))
     y = np.zeros((mb_size, self.max_label_length))
     labelUtil = LabelUtil.getInstance()
     label_lengths = []
     for i in range(mb_size):
         feat = features[i]
         feat = self.normalize(feat)  # Center using means and std
         x[i, :feat.shape[0], :] = feat
         if is_bi_graphemes:
             label = generate_bi_graphemes_label(texts[i])
             label = labelUtil.convert_bi_graphemes_to_num(label)
             y[i, :len(label)] = label
         else:
             label = labelUtil.convert_word_to_num(texts[i])
             y[i, :len(texts[i])] = label
         label_lengths.append(len(label))
     return {
          'x': x,  # 0-padded features of shape (mb_size, timesteps, feat_dim)
         'y': y,  # list(int) Flattened labels (integer sequences)
         'texts': texts,  # list(str) Original texts
         'input_lengths': input_lengths,  # list(int) Length of each input
         'label_lengths': label_lengths,  # list(int) Length of each label
     }
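Both prepare_minibatch variants above zero-pad each (timesteps, feature_dim) feature array to a common length before stacking the clips into one batch tensor. A small numpy sketch of just that padding step; the clip lengths and the 161-dim feature size are made-up values:

import numpy as np

# Hypothetical minibatch: three clips of different lengths, 161-dim features.
features = [np.random.rand(t, 161) for t in (50, 80, 65)]
max_seq_length = 100

x = np.zeros((len(features), max_seq_length, 161))
for i, feat in enumerate(features):
    x[i, :feat.shape[0], :] = feat  # copy the real frames, leave the rest zero

input_lengths = [f.shape[0] for f in features]
print(x.shape, input_lengths)  # (3, 100, 161) [50, 80, 65]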
Example #8
def get_scorer(alpha=1., beta=1.):
    try:
        from swig_wrapper import Scorer

        labelUtil = LabelUtil()
        vocab_list = [chars.encode("utf-8") for chars in labelUtil.byList]
        log.info("vocab_list len is %d" % len(vocab_list))
        _ext_scorer = Scorer(alpha, beta, args.config.get('common', 'kenlm'),
                             vocab_list)
        lm_char_based = _ext_scorer.is_character_based()
        lm_max_order = _ext_scorer.get_max_order()
        lm_dict_size = _ext_scorer.get_dict_size()
        log.info("language model: "
                 "is_character_based = %d," % lm_char_based +
                 " max_order = %d," % lm_max_order +
                 " dict_size = %d" % lm_dict_size)
        return _ext_scorer
    except ImportError:
        import kenlm
        km = kenlm.Model(args.config.get('common', 'kenlm'))
        return km.score
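get_scorer prefers the swig_wrapper Scorer (a KenLM-backed scorer for beam search) and falls back to plain KenLM sentence scoring when the wrapper cannot be imported. A minimal sketch of that fallback path, assuming the kenlm Python package is installed; 'lm.binary' is a placeholder model path:

import kenlm  # assumes the kenlm Python package is available

lm = kenlm.Model('lm.binary')            # placeholder language-model file
score = lm.score('the quick brown fox')  # log10 probability under the LM
print(score)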
Example #9
def load_data(args):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'mode must be one of the following: train, predict, load')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train',
                                                  'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean(
        'train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()
    if mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json, max_duration=max_duration)
        datagen.load_validation_data(val_json, max_duration=max_duration)
        if is_bi_graphemes:
            if not os.path.isfile(
                    "resources/unicodemap_en_baidu_bi_graphemes.csv"
            ) or overwrite_bi_graphemes_dictionary:
                load_labelutil(labelUtil=labelUtil,
                               is_bi_graphemes=False,
                               language=language)
                generate_bi_graphemes_dictionary(datagen.train_texts +
                                                 datagen.val_texts)
        load_labelutil(labelUtil=labelUtil,
                       is_bi_graphemes=is_bi_graphemes,
                       language=language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                log.info("Generate mean and std from samples")
                normalize_target_k = args.config.getint(
                    'train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                log.info("Read mean and std from meta files")
                datagen.get_meta_from_file(
                    np.loadtxt(
                        generate_file_path(save_dir, model_name,
                                           'feats_mean')),
                    np.loadtxt(
                        generate_file_path(save_dir, model_name, 'feats_std')))
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(
                    generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(
                    generate_file_path(save_dir, model_name, 'feats_std')))

    elif mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json, max_duration=max_duration)
        labelutil = load_labelutil(labelUtil, is_bi_graphemes, language="en")
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train'
                                             or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by their duration in ascending order to implement sortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = \
            datagen.get_max_label_length(partition="train", is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = \
            datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean(
        'train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(
            partition="train",
            count=datagen.count,
            datagen=datagen,
            batch_size=batch_size,
            num_label=max_label_length,
            init_states=init_states,
            seq_length=max_t_count,
            width=whcs.width,
            height=whcs.height,
            sort_by_duration=sort_by_duration,
            is_bi_graphemes=is_bi_graphemes,
            buckets=buckets,
            save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    if mode == 'train' or mode == 'load':
        if is_bucketing:
            validation_loaded = BucketSTTIter(
                partition="validation",
                count=datagen.val_count,
                datagen=datagen,
                batch_size=batch_size,
                num_label=max_label_length,
                init_states=init_states,
                seq_length=max_t_count,
                width=whcs.width,
                height=whcs.height,
                sort_by_duration=False,
                is_bi_graphemes=is_bi_graphemes,
                buckets=buckets,
                save_feature_as_csvfile=save_feature_as_csvfile)
        else:
            validation_loaded = STTIter(
                partition="validation",
                count=datagen.val_count,
                datagen=datagen,
                batch_size=batch_size,
                num_label=max_label_length,
                init_states=init_states,
                seq_length=max_t_count,
                width=whcs.width,
                height=whcs.height,
                sort_by_duration=False,
                is_bi_graphemes=is_bi_graphemes,
                save_feature_as_csvfile=save_feature_as_csvfile)
        return data_loaded, validation_loaded, args
    elif mode == 'predict':
        return data_loaded, args
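load_data reads everything from the parsed cfg file. A sketch of how a matching config could be assembled with configparser, listing only the sections and keys this example touches; every value is a placeholder, not a recommended setting:

import configparser

config = configparser.ConfigParser()
config['common'] = {'mode': 'train', 'batch_size': '32',
                    'prefix': 'deepspeech', 'is_bi_graphemes': 'False'}
config['data'] = {'width': '161', 'height': '1', 'channel': '1', 'stride': '1',
                  'train_json': 'train.json', 'val_json': 'val.json',
                  'test_json': 'test.json', 'max_duration': '16.0',
                  'language': 'en'}
config['arch'] = {'is_batchnorm': 'True', 'is_bucketing': 'False',
                  'buckets': '[200, 400, 800, 1600]',
                  'arch_file': 'arch_deepspeech'}
config['train'] = {'overwrite_meta_files': 'True',
                   'overwrite_bi_graphemes_dictionary': 'False',
                   'normalize_target_k': '10000',
                   'save_feature_as_csvfile': 'False'}

with open('example.cfg', 'w') as f:
    config.write(f)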
Example #10
def load_data(args):
    mode = args.config.get('common', 'mode')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train',
                                                  'overwrite_meta_files')

    if mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json)
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    elif mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json)
        # test bi-graphemes

        language = args.config.get('data', 'language')
        is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')

        if overwrite_meta_files and is_bi_graphemes:
            generate_bi_graphemes_dictionary(datagen.train_texts)

        labelUtil = LabelUtil.getInstance()
        if language == "en":
            if is_bi_graphemes:
                try:
                    labelUtil.load_unicode_set(
                        "resources/unicodemap_en_baidu_bi_graphemes.csv")
                except:
                    raise Exception(
                        "There is no resources/unicodemap_en_baidu_bi_graphemes.csv. "
                        "Please set overwrite_meta_files to True in the train section.")
            else:
                labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
        else:
            raise Exception("Error: Language Type: %s" % language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                normalize_target_k = args.config.getint(
                    'train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                datagen.get_meta_from_file(
                    np.loadtxt(
                        generate_file_path(save_dir, model_name,
                                           'feats_mean')),
                    np.loadtxt(
                        generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)

        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(
                    generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(
                    generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)
    else:
        raise Exception(
            'mode must be defined in the cfg file as one of: train, predict, load.')

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm:
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by their duration in ascending order to implement sortaGrad

    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = datagen.get_max_label_length(
            partition="train", is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = datagen.get_max_label_length(
            partition="test", is_bi_graphemes=is_bi_graphemes)
    else:
        raise Exception(
            'mode must be defined in the cfg file as one of: train, predict, load.')

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    if mode == "train":
        sort_by_duration = True
    else:
        sort_by_duration = False

    data_loaded = STTIter(partition="train",
                          count=datagen.count,
                          datagen=datagen,
                          batch_size=batch_size,
                          num_label=max_label_length,
                          init_states=init_states,
                          seq_length=max_t_count,
                          width=whcs.width,
                          height=whcs.height,
                          sort_by_duration=sort_by_duration,
                          is_bi_graphemes=is_bi_graphemes)

    if mode == 'predict':
        return data_loaded, args
    else:
        validation_loaded = STTIter(partition="validation",
                                    count=datagen.val_count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=False,
                                    is_bi_graphemes=is_bi_graphemes)
        return data_loaded, validation_loaded, args
Example #11
    def __init__(self, args):
        self.args = args
        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check that the number of GPUs is a positive divisor of the batch size for data parallelism
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(log)
        self.config_logger(args.config)

        save_dir = 'checkpoints'
        model_name = self.args.config.get('common', 'prefix')
        max_freq = self.args.config.getint('data', 'max_freq')
        self.datagen = DataGenerator(save_dir=save_dir,
                                     model_name=model_name,
                                     max_freq=max_freq)
        self.datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

        self.buckets = json.loads(self.args.config.get('arch', 'buckets'))

        default_bucket_key = self.buckets[-1]
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)
        symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        # all_layers = symbol.get_internals()
        # s_sym = all_layers['concat36457_output']
        # sm = mx.sym.SoftmaxOutput(data=s_sym, name='softmax')

        # self.model = STTBucketingModule(
        #     sym_gen=self.model_loaded,
        #     default_bucket_key=default_bucket_key,
        #     context=self.contexts
        # )
        s_mod = mx.mod.BucketingModule(sym_gen=self.model_loaded,
                                       context=self.contexts,
                                       default_bucket_key=default_bucket_key)

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        self.init_states = prepare_data_template.prepare_data(self.args)
        self.width = self.args.config.getint('data', 'width')
        self.height = self.args.config.getint('data', 'height')
        s_mod.bind(data_shapes=[
            ('data',
             (self.batch_size, default_bucket_key, self.width * self.height))
        ] + self.init_states,
                   for_training=False)

        s_mod.set_params(self.arg_params,
                         self.aux_params,
                         allow_extra=True,
                         allow_missing=True)
        for bucket in self.buckets:
            provide_data = [
                ('data', (self.batch_size, bucket, self.width * self.height))
            ] + self.init_states
            s_mod.switch_bucket(bucket_key=bucket, data_shapes=provide_data)

        self.model = s_mod

        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            log.info("vocab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            log.info("language model: "
                     "is_character_based = %d," % lm_char_based +
                     " max_order = %d," % lm_max_order +
                     " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=km.score)
            self.scorer = km.score
Example #12
class Net(object):
    def __init__(self, args):
        self.args = args
        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check that the number of GPUs is a positive divisor of the batch size for data parallelism
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(log)
        self.config_logger(args.config)

        save_dir = 'checkpoints'
        model_name = self.args.config.get('common', 'prefix')
        max_freq = self.args.config.getint('data', 'max_freq')
        self.datagen = DataGenerator(save_dir=save_dir,
                                     model_name=model_name,
                                     max_freq=max_freq)
        self.datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

        self.buckets = json.loads(self.args.config.get('arch', 'buckets'))

        default_bucket_key = self.buckets[-1]
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)
        symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        # all_layers = symbol.get_internals()
        # s_sym = all_layers['concat36457_output']
        # sm = mx.sym.SoftmaxOutput(data=s_sym, name='softmax')

        # self.model = STTBucketingModule(
        #     sym_gen=self.model_loaded,
        #     default_bucket_key=default_bucket_key,
        #     context=self.contexts
        # )
        s_mod = mx.mod.BucketingModule(sym_gen=self.model_loaded,
                                       context=self.contexts,
                                       default_bucket_key=default_bucket_key)

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        self.init_states = prepare_data_template.prepare_data(self.args)
        self.width = self.args.config.getint('data', 'width')
        self.height = self.args.config.getint('data', 'height')
        s_mod.bind(data_shapes=[
            ('data',
             (self.batch_size, default_bucket_key, self.width * self.height))
        ] + self.init_states,
                   for_training=False)

        s_mod.set_params(self.arg_params,
                         self.aux_params,
                         allow_extra=True,
                         allow_missing=True)
        for bucket in self.buckets:
            provide_data = [
                ('data', (self.batch_size, bucket, self.width * self.height))
            ] + self.init_states
            s_mod.switch_bucket(bucket_key=bucket, data_shapes=provide_data)

        self.model = s_mod

        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            log.info("vocab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            log.info("language model: "
                     "is_character_based = %d," % lm_char_based +
                     " max_order = %d," % lm_max_order +
                     " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=km.score)
            self.scorer = km.score

    def getTrans(self, wav_file):
        res = spectrogram_from_file(wav_file, noise_percent=0)
        buck = bisect.bisect_left(self.buckets, len(res))
        bucket_key = self.buckets[buck]
        res = self.datagen.normalize(res)
        d = np.zeros((self.batch_size, bucket_key, res.shape[1]))
        d[0, :res.shape[0], :] = res
        init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states]

        model_loaded = self.model

        provide_data = [
            ('data', (self.batch_size, bucket_key, self.width * self.height))
        ] + self.init_states
        data_batch = mx.io.DataBatch([mx.nd.array(d)] + init_state_arrays,
                                     label=None,
                                     bucket_key=bucket_key,
                                     provide_data=provide_data,
                                     provide_label=None)
        st = time.time()
        model_loaded.forward(data_batch, is_train=False)
        probs = model_loaded.get_outputs()[0].asnumpy()
        log.info("forward cost %.3f" % (time.time() - st))
        st = time.time()
        res = ctc_greedy_decode(probs, self.labelUtil.byList)
        log.info("greedy decode cost %.3f, result is:\n%s" %
                 (time.time() - st, res))
        beam_size = 5
        from stt_metric import ctc_beam_decode
        st = time.time()
        results = ctc_beam_decode(scorer=self.scorer,
                                  beam_size=beam_size,
                                  vocab=self.labelUtil.byList,
                                  probs=probs)
        log.info("beam decode cost %.3f, result is:\n%s" %
                 (time.time() - st, "\n".join(results)))
        return "greedy:\n" + res + "\nbeam:\n" + "\n".join(results)
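getTrans pushes a single wav file through the bucketed model, greedy-decodes the CTC output, then beam-decodes it with the language-model scorer. A hedged usage sketch; parse_args and Net come from this project, while 'net.cfg' and 'sample.wav' are placeholder paths:

args = parse_args('net.cfg')       # placeholder config path
net = Net(args)                    # loads checkpoint, buckets and LM scorer
print(net.getTrans('sample.wav'))  # greedy and beam transcriptions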
Example #13
def load_data(args):
    mode = args.config.get('common', 'mode')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    language = args.config.get('data', 'language')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')

    labelUtil = LabelUtil.getInstance()
    if language == "en":
        if is_bi_graphemes:
            try:
                labelUtil.load_unicode_set("resources/unicodemap_en_baidu_bi_graphemes.csv")
            except:
                raise Exception("There is no resources/unicodemap_en_baidu_bi_graphemes.csv. Please set overwrite_meta_files to True in the train section.")
        else:
            labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
    else:
        raise Exception("Error: Language Type: %s" % language)
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

    if mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json)
        datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                                   np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    elif mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json)
        # test bi-graphemes

        if overwrite_meta_files and is_bi_graphemes:
            generate_bi_graphemes_dictionary(datagen.train_texts)

        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                                           np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)

        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                                       np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)
    else:
        raise Exception(
            'mode must be defined in the cfg file as one of: train, predict, load.')

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm:
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by their duration in ascending order to implement sortaGrad

    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = datagen.get_max_label_length(partition="train",is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = datagen.get_max_label_length(partition="test",is_bi_graphemes=is_bi_graphemes)
    else:
        raise Exception(
            'mode must be defined in the cfg file as one of: train, predict, load.')

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    if mode == "train":
        sort_by_duration = True
    else:
        sort_by_duration = False

    data_loaded = STTIter(partition="train",
                          count=datagen.count,
                          datagen=datagen,
                          batch_size=batch_size,
                          num_label=max_label_length,
                          init_states=init_states,
                          seq_length=max_t_count,
                          width=whcs.width,
                          height=whcs.height,
                          sort_by_duration=sort_by_duration,
                          is_bi_graphemes=is_bi_graphemes)

    if mode == 'predict':
        return data_loaded, args
    else:
        validation_loaded = STTIter(partition="validation",
                                    count=datagen.val_count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=False,
                                    is_bi_graphemes=is_bi_graphemes)
        return data_loaded, validation_loaded, args
Example #14
def load_data(args):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception('mode must be one of the following: train, predict, load')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean('train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()
    if mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json, max_duration=max_duration)
        datagen.load_validation_data(val_json, max_duration=max_duration)
        if is_bi_graphemes:
            if not os.path.isfile("resources/unicodemap_en_baidu_bi_graphemes.csv") or overwrite_bi_graphemes_dictionary:
                load_labelutil(labelUtil=labelUtil, is_bi_graphemes=False, language=language)
                generate_bi_graphemes_dictionary(datagen.train_texts+datagen.val_texts)
        load_labelutil(labelUtil=labelUtil, is_bi_graphemes=is_bi_graphemes, language=language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                log.info("Generate mean and std from samples")
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                log.info("Read mean and std from meta files")
                datagen.get_meta_from_file(
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    elif mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json, max_duration=max_duration)
        labelutil = load_labelutil(labelUtil, is_bi_graphemes, language="en")
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by their duration in ascending order to implement sortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = \
            datagen.get_max_label_length(partition="train", is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = \
            datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean('train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(partition="train",
                                    count=datagen.count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=sort_by_duration,
                                    is_bi_graphemes=is_bi_graphemes,
                                    buckets=buckets,
                                    save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    if mode == 'train' or mode == 'load':
        if is_bucketing:
            validation_loaded = BucketSTTIter(partition="validation",
                                              count=datagen.val_count,
                                              datagen=datagen,
                                              batch_size=batch_size,
                                              num_label=max_label_length,
                                              init_states=init_states,
                                              seq_length=max_t_count,
                                              width=whcs.width,
                                              height=whcs.height,
                                              sort_by_duration=False,
                                              is_bi_graphemes=is_bi_graphemes,
                                              buckets=buckets,
                                              save_feature_as_csvfile=save_feature_as_csvfile)
        else:
            validation_loaded = STTIter(partition="validation",
                                        count=datagen.val_count,
                                        datagen=datagen,
                                        batch_size=batch_size,
                                        num_label=max_label_length,
                                        init_states=init_states,
                                        seq_length=max_t_count,
                                        width=whcs.width,
                                        height=whcs.height,
                                        sort_by_duration=False,
                                        is_bi_graphemes=is_bi_graphemes,
                                        save_feature_as_csvfile=save_feature_as_csvfile)
        return data_loaded, validation_loaded, args
    elif mode == 'predict':
        return data_loaded, args
Example #15
 def prepare_minibatch_fbank(self,
                             audio_paths,
                             texts,
                             overwrite=False,
                             is_bi_graphemes=False,
                             seq_length=-1,
                             save_feature_as_csvfile=False,
                             language="en",
                             zh_type="zi",
                             noise_percent=0.4):
     """ Featurize a minibatch of audio, zero pad them and return a dictionary
     Params:
         audio_paths (list(str)): List of paths to audio files
         texts (list(str)): List of texts corresponding to the audio files
     Returns:
         dict: See below for contents
     """
     assert len(audio_paths) == len(texts), \
         "Inputs and outputs to the network must be of the same number"
      # Features is a list of (channel, timesteps, feature_dim) arrays
      # (for fbank: channel=3, feature_dim=41), unlike the (timesteps, feature_dim=161)
      # spectrogram features used elsewhere
     # Calculate the features for each audio clip, as the log of the
     # Fourier Transform of the audio
     features = [
         self.featurize_fbank(
             a,
             overwrite=overwrite,
             save_feature_as_csvfile=save_feature_as_csvfile,
             noise_percent=noise_percent,
             seq_length=seq_length) for a in audio_paths
     ]
     input_lengths = [f.shape[1] for f in features]
     channel, timesteps, feature_dim = features[0].shape
     mb_size = len(features)
     # Pad all the inputs so that they are all the same length
     if seq_length == -1:
         x = np.zeros((mb_size, channel, self.max_seq_length, feature_dim))
     else:
         x = np.zeros((mb_size, channel, seq_length, feature_dim))
     y = np.zeros((mb_size, self.max_label_length))
     labelUtil = LabelUtil()
     label_lengths = []
     for i in range(mb_size):
         feat = features[i]
         feat = self.normalize_fbank(feat)  # Center using means and std
          x[i, :, :feat.shape[1], :] = feat  # pad with zeros (pad with noise instead?)
         if language == "en" and is_bi_graphemes:
             label = generate_bi_graphemes_label(texts[i])
             label = labelUtil.convert_bi_graphemes_to_num(label)
             y[i, :len(label)] = label
         elif language == "en" and not is_bi_graphemes:
             label = labelUtil.convert_word_to_num(texts[i])
             y[i, :len(texts[i])] = label
         elif language == "zh" and zh_type == "phone":
             label = generate_phone_label(texts[i])
             label = labelUtil.convert_bi_graphemes_to_num(label)
             y[i, :len(label)] = label
         elif language == "zh" and zh_type == "py":
             label = generate_py_label(texts[i])
             label = labelUtil.convert_bi_graphemes_to_num(label)
             y[i, :len(label)] = label
         elif language == "zh" and zh_type == "zi":
             label = generate_zi_label(texts[i])
             label = labelUtil.convert_bi_graphemes_to_num(label)
             y[i, :len(label)] = label
         label_lengths.append(len(label))
     return {
          'x': x,  # 0-padded features of shape (mb_size, channel, timesteps, feat_dim)
         'y': y,  # list(int) Flattened labels (integer sequences)
         'texts': texts,  # list(str) Original texts
         'input_lengths': input_lengths,  # list(int) Length of each input
         'label_lengths': label_lengths,  # list(int) Length of each label
     }
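The fbank variant differs from prepare_minibatch mainly in layout: each clip is a (channel, timesteps, feature_dim) array, so the padded batch is 4-D instead of 3-D. A tiny numpy sketch of that padding; the 3 channels, 41-dim filterbanks and clip lengths are made-up values:

import numpy as np

# Hypothetical fbank features: 3 channels, 41-dim filterbanks, varying lengths.
features = [np.random.rand(3, t, 41) for t in (50, 80, 65)]
max_seq_length = 100

x = np.zeros((len(features), 3, max_seq_length, 41))
for i, feat in enumerate(features):
    x[i, :, :feat.shape[1], :] = feat  # pad the time axis with zeros

print(x.shape)  # (3, 3, 100, 41)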
Example #16
    def __init__(self):
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided. ' +
                            'e.g. python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        # set parameters from cfg file
        # give random seed
        self.random_seed = self.args.config.getint('common', 'random_seed')
        self.mx_random_seed = self.args.config.getint('common',
                                                      'mx_random_seed')
        # random seed for shuffling data list
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))
        # set log file name
        self.log_filename = self.args.config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check that the number of GPUs is a positive divisor of the batch size for data parallelism
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(self.args.config)

        default_bucket_key = 1600
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        self.model = STTBucketingModule(sym_gen=self.model_loaded,
                                        default_bucket_key=default_bucket_key,
                                        context=self.contexts)

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = self.args.config.getint('data', 'width')
        height = self.args.config.getint('data', 'height')
        self.model.bind(data_shapes=[
            ('data', (self.batch_size, default_bucket_key, width * height))
        ] + init_states,
                        label_shapes=[
                            ('label',
                             (self.batch_size,
                              self.args.config.getint('arch',
                                                      'max_label_length')))
                        ],
                        for_training=True)

        _, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        self.model.set_params(self.arg_params,
                              self.aux_params,
                              allow_extra=True,
                              allow_missing=True)

        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            self.log.info("vacab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
            # self.eval_metric = EvalSTTMetric(batch_size=self.batch_size, num_gpu=self.num_gpu, is_logging=True,
            #                                  scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.scorer = km.score
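
This constructor prefers the compiled swig_wrapper Scorer for beam-search LM scoring and falls back to the kenlm Python binding when the extension is not available. A minimal sketch of that fallback, reusing the alpha/beta weights from the snippet and assuming a KenLM model file at a hypothetical path:

def build_scorer(lm_path, vocab_list, alpha=0.26, beta=0.1):
    try:
        # compiled CTC-decoder extension with LM weights, if it was built
        from swig_wrapper import Scorer
        return Scorer(alpha, beta, lm_path, vocab_list)
    except ImportError:
        # fall back to plain sentence scoring with the kenlm binding
        import kenlm
        model = kenlm.Model(lm_path)
        return model.score  # score(sentence) -> log10 probability

# scorer = build_scorer("lm.binary", [c.encode("utf-8") for c in vocab_list])
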
Exemple #17
0
def load_data(args, wav_file):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'mode must be one of the following: train, predict, load')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train',
                                                  'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean(
        'train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    max_freq = args.config.getint('data', 'max_freq')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil()

    # test_json = "resources/d.json"
    datagen = DataGenerator(save_dir=save_dir,
                            model_name=model_name,
                            max_freq=max_freq)
    datagen.train_audio_paths = [wav_file]
    datagen.train_durations = [get_duration_wave(wav_file)]
    datagen.train_texts = ["1 1"]
    datagen.count = 1
    # datagen.load_train_data(test_json, max_duration=max_duration)
    load_labelutil(labelUtil, is_bi_graphemes, language="zh")
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
    datagen.get_meta_from_file(
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train'
                                             or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    max_t_count = datagen.get_max_seq_length(partition="test")
    max_label_length = \
        datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean(
        'train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(
            partition="train",
            count=datagen.count,
            datagen=datagen,
            batch_size=batch_size,
            num_label=max_label_length,
            init_states=init_states,
            seq_length=max_t_count,
            width=whcs.width,
            height=whcs.height,
            sort_by_duration=sort_by_duration,
            is_bi_graphemes=is_bi_graphemes,
            buckets=buckets,
            save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    return data_loaded, args
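
load_data restores the feature statistics (feats_mean, feats_std) written at training time and hands them to the data generator, which presumably standardizes every spectrogram with them. A minimal sketch of that normalization step, assuming 161 frequency bins and stand-in arrays in place of np.loadtxt on the saved files:

import numpy as np

def normalize(feat, feats_mean, feats_std, eps=1e-14):
    # standardize a (timesteps, feat_dim) spectrogram with training-set statistics
    return (feat - feats_mean) / (feats_std + eps)

feats_mean = np.zeros(161)            # stand-in for np.loadtxt(... 'feats_mean')
feats_std = np.ones(161)              # stand-in for np.loadtxt(... 'feats_std')
feat = np.random.rand(200, 161)       # fake spectrogram: 200 frames x 161 bins
print(normalize(feat, feats_mean, feats_std).shape)  # (200, 161)
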
Exemple #18
0
    def update(self, labels, preds):
        check_label_shapes(labels, preds)
        if self.is_logging:
            log = LogUtil().getlogger()
            labelUtil = LabelUtil()
        self.batch_loss = 0.
        shouldPrint = True
        host_name = socket.gethostname()
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()
            seq_length = len(pred) / int(
                int(self.batch_size) / int(self.num_gpu))
            # sess = tf.Session()
            for i in range(int(int(self.batch_size) / int(self.num_gpu))):
                l = remove_blank(label[i])
                # p = []
                probs = []
                for k in range(int(seq_length)):
                    # p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                    probs.append(
                        pred[k * int(int(self.batch_size) / int(self.num_gpu))
                             + i])
                # p = pred_best(p)
                probs = np.array(probs)
                st = time.time()
                beam_size = 20
                results = ctc_beam_decode(self.scorer, beam_size,
                                          labelUtil.byList, probs)
                log.info("decode by ctc_beam cost %.2f result: %s" %
                         (time.time() - st, "\n".join(results)))

                res_str1 = ctc_greedy_decode(probs, labelUtil.byList)
                log.info("decode by pred_best: %s" % res_str1)

                # max_time_steps = int(seq_length)
                # input_log_prob_matrix_0 = np.log(probs)  # + 2.0
                #
                # # len max_time_steps array of batch_size x depth matrices
                # inputs = ([
                #   input_log_prob_matrix_0[t, :][np.newaxis, :] for t in range(max_time_steps)]
                # )
                #
                # inputs_t = [ops.convert_to_tensor(x) for x in inputs]
                # inputs_t = array_ops.stack(inputs_t)
                #
                # st = time.time()
                # # run CTC beam search decoder in tensorflow
                # decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(inputs_t,
                #                                                            [max_time_steps],
                #                                                            beam_width=10,
                #                                                            top_paths=3,
                #                                                            merge_repeated=False)
                # tf_decoded, tf_log_probs = sess.run([decoded, log_probabilities])
                # st1 = time.time() - st
                # for index in range(3):
                #   tf_result = ''.join([labelUtil.byIndex.get(i + 1, ' ') for i in tf_decoded[index].values])
                #   print("%.2f elpse %.2f, %s" % (tf_log_probs[0][index], st1, tf_result))
                l_distance = editdistance.eval(
                    labelUtil.convert_num_to_word(l).split(" "), res_str1)
                # l_distance_beam = editdistance.eval(labelUtil.convert_num_to_word(l).split(" "), beam_result[0][1])
                l_distance_beam_cpp = editdistance.eval(
                    labelUtil.convert_num_to_word(l).split(" "), results[0])
                self.total_n_label += len(l)
                # self.total_l_dist_beam += l_distance_beam
                self.total_l_dist_beam_cpp += l_distance_beam_cpp
                self.total_l_dist += l_distance
                this_cer = float(l_distance) / float(len(l))
                if self.is_logging:
                    # log.info("%s label: %s " % (host_name, labelUtil.convert_num_to_word(l)))
                    # log.info("%s pred : %s , cer: %f (distance: %d/ label length: %d)" % (
                    #     host_name, labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
                    log.info("%s label: %s " %
                             (host_name, labelUtil.convert_num_to_word(l)))
                    log.info(
                        "%s pred : %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, res_str1, this_cer, l_distance, len(l)))
                    # log.info("%s predb: %s , cer: %f (distance: %d/ label length: %d)" % (
                    #     host_name, " ".join(beam_result[0][1]), float(l_distance_beam) / len(l), l_distance_beam,
                    #     len(l)))
                    log.info(
                        "%s predc: %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, " ".join(
                            results[0]), float(l_distance_beam_cpp) / len(l),
                           l_distance_beam_cpp, len(l)))
                self.total_ctc_loss += self.batch_loss
                self.placeholder = res_str1 + "\n" + "\n".join(results)
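
update() decodes each utterance both greedily and with the external beam decoder, then accumulates an edit-distance-based error rate. Below is a minimal, self-contained sketch of the greedy path (the equivalent of pred_best/remove_blank followed by an edit distance); the blank index, the toy posterior matrix, and the reference label are illustrative assumptions.

import numpy as np

def greedy_ctc_decode(probs, blank=0):
    # probs: (timesteps, n_classes); argmax each frame, collapse repeats, drop blanks
    best_path = np.argmax(probs, axis=1)
    decoded, prev = [], None
    for idx in best_path:
        if idx != prev and idx != blank:
            decoded.append(int(idx))
        prev = idx
    return decoded

def levenshtein(a, b):
    # single-row dynamic-programming edit distance between two sequences
    dp = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, y in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (x != y))
    return dp[-1]

probs = np.array([[0.1, 0.8, 0.1],    # toy 4-frame, 3-class posteriors (class 0 = blank)
                  [0.1, 0.8, 0.1],
                  [0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
pred = greedy_ctc_decode(probs)       # -> [1, 2]
label = [1, 2]
print(pred, levenshtein(label, pred) / float(len(label)))  # CER 0.0 for a perfect match
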
Exemple #19
0
    mx.random.seed(hash(datetime.now()))
    # set parameters from cfg file
    args = parse_args(sys.argv[1])

    log_filename = args.config.get('common', 'log_filename')
    log = LogUtil(filename=log_filename).getlogger()

    # set parameters from data section(common)
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception(
            'mode must be one of the following: train, predict, load. Set it in the cfg file.')

    # get meta file where character to number conversions are defined
    language = args.config.get('data', 'language')
    labelUtil = LabelUtil.getInstance()
    if language == "en":
        labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
    else:
        raise Exception("Error: Language Type: %s" % language)
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    batch_size = args.config.getint('common', 'batch_size')

    # check that the number of GPUs is a positive divisor of the batch size
    if batch_size % num_gpu != 0:
        raise Exception('num_gpu should be a positive divisor of batch_size')

    if mode == "predict":
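
The divisibility check above exists because the global batch is split evenly across devices for data parallelism; each GPU then sees batch_size / num_gpu utterances, which is also the slice size the metric's update() loops over. A minimal sketch with illustrative numbers:

batch_size, num_gpu = 32, 4
if batch_size % num_gpu != 0:
    raise Exception('num_gpu should be a positive divisor of batch_size')
per_gpu_batch = batch_size // num_gpu  # utterances handled by each device
print(per_gpu_batch)                   # -> 8
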
Exemple #20
0
    def __init__(self):
        if len(sys.argv) <= 1:
            raise Exception('cfg file path must be provided. ' +
                            'e.g. python main.py --configfile examplecfg.cfg')
        self.args = parse_args(sys.argv[1])
        # set parameters from cfg file
        # give random seed
        self.random_seed = self.args.config.getint('common', 'random_seed')
        self.mx_random_seed = self.args.config.getint('common',
                                                      'mx_random_seed')
        # random seed for shuffling data list
        if self.random_seed != -1:
            np.random.seed(self.random_seed)
        # set mx.random.seed to give seed for parameter initialization
        if self.mx_random_seed != -1:
            mx.random.seed(self.mx_random_seed)
        else:
            mx.random.seed(hash(datetime.now()))
        # set log file name
        self.log_filename = self.args.config.get('common', 'log_filename')
        self.log = LogUtil(filename=self.log_filename).getlogger()

        # set parameters from data section(common)
        self.mode = self.args.config.get('common', 'mode')

        save_dir = 'checkpoints'
        model_name = self.args.config.get('common', 'prefix')
        max_freq = self.args.config.getint('data', 'max_freq')
        self.datagen = DataGenerator(save_dir=save_dir,
                                     model_name=model_name,
                                     max_freq=max_freq)
        self.datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

        self.buckets = json.loads(self.args.config.get('arch', 'buckets'))

        # get meta file where character to number conversions are defined

        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # check that the number of GPUs is a positive divisor of the batch size (data parallelism)
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')

        # log current config
        self.config_logger = ConfigLogger(self.log)
        self.config_logger(self.args.config)

        default_bucket_key = 1600
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(95))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common',
                                                      'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes',
                             str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # self.load_optimizer_states = self.args.config.getboolean('load', 'load_optimizer_states')

        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(
            self.args)

        # self.model = STTBucketingModule(
        #     sym_gen=self.model_loaded,
        #     default_bucket_key=default_bucket_key,
        #     context=self.contexts
        # )

        from importlib import import_module
        prepare_data_template = import_module(
            self.args.config.get('arch', 'arch_file'))
        init_states = prepare_data_template.prepare_data(self.args)
        width = self.args.config.getint('data', 'width')
        height = self.args.config.getint('data', 'height')
        for bucket in self.buckets:
            net, init_state_names, ll = self.model_loaded(bucket)
            net.save('checkpoints/%s-symbol.json' % bucket)
        input_shapes = dict([('data',
                              (self.batch_size, default_bucket_key,
                               width * height))] + init_states + [('label',
                                                                   (1, 18))])
        # self.executor = net.simple_bind(ctx=mx.cpu(), **input_shapes)

        # self.model.bind(data_shapes=[('data', (self.batch_size, default_bucket_key, width * height))] + init_states,
        #                 label_shapes=[
        #                     ('label', (self.batch_size, self.args.config.getint('arch', 'max_label_length')))],
        #                 for_training=True)

        symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        all_layers = symbol.get_internals()
        concat = all_layers['concat36457_output']
        sm = mx.sym.SoftmaxOutput(data=concat, name='softmax')
        self.executor = sm.simple_bind(ctx=mx.cpu(), **input_shapes)
        # self.model.set_params(self.arg_params, self.aux_params, allow_extra=True, allow_missing=True)

        for key in self.executor.arg_dict.keys():
            if key in self.arg_params:
                self.arg_params[key].copyto(self.executor.arg_dict[key])
        init_state_names.remove('data')
        init_state_names.sort()
        self.states_dict = dict(
            zip(init_state_names, self.executor.outputs[1:]))
        self.input_arr = mx.nd.zeros(
            (self.batch_size, default_bucket_key, width * height))

        try:
            from swig_wrapper import Scorer

            vocab_list = [
                chars.encode("utf-8") for chars in self.labelUtil.byList
            ]
            self.log.info("vacab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'),
                                 vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            self.log.info("language model: "
                          "is_character_based = %d," % lm_char_based +
                          " max_order = %d," % lm_max_order +
                          " dict_size = %d" % lm_dict_size)
            self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                             num_gpu=self.num_gpu,
                                             is_logging=True,
                                             scorer=_ext_scorer)
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                             num_gpu=self.num_gpu,
                                             is_logging=True,
                                             scorer=km.score)
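
The parameter-copy loop earlier in this constructor (and set_params with allow_extra/allow_missing) only transfers weights whose names exist on both sides: extra checkpoint keys are ignored and newly added layers keep their initialization. A minimal plain-NumPy sketch of that matching-key copy; all parameter names here are illustrative.

import numpy as np

checkpoint_params = {
    "conv0_weight": np.random.rand(32, 1, 3, 3),
    "fc_weight": np.random.rand(10, 128),
    "old_layer_weight": np.random.rand(4, 4),    # extra key: silently ignored
}
executor_args = {
    "conv0_weight": np.zeros((32, 1, 3, 3)),
    "fc_weight": np.zeros((10, 128)),
    "new_layer_weight": np.zeros((8, 8)),        # missing from the checkpoint: keeps its init
}
for key in executor_args:
    if key in checkpoint_params:
        executor_args[key][:] = checkpoint_params[key]  # in-place copy, like copyto()
print(sorted(k for k in executor_args if k in checkpoint_params))
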