Example #1
File: utils.py  Project: zw76859420/ASR-1
    def get_test_batches_with_buckets_and_target(self, src_path, dst_path, tokens_per_batch):
        buckets = list(range(50, 10000, 5))

        def select_bucket(sl):
            for l1 in buckets:
                if sl < l1:
                    return l1
            raise Exception("The sequence is too long: ({})".format(sl))

        uttid_target_map = {}
        for line in codecs.open(dst_path, 'r', 'utf-8'):
            line = line.strip()
            if not line:
                continue
            splits = re.split(r'\s+', line)
            uttid = splits[0].strip()
            target = splits[1:]
            uttid_target_map[uttid] = target
        logging.info('loaded dst_path=' + str(dst_path) + ',size=' + str(len(uttid_target_map)))

        caches = {}
        for bucket in buckets:
            caches[bucket] = [[], [], 0, 0]  # feats, targets, feat frames, target tokens

        scp_reader = zark.ArkReader(src_path)
        count = 0
        while True:
            uttid, input, loop = scp_reader.read_next_utt()
            if loop:
                break

            target = uttid_target_map.get(uttid)
            if target is None:
                logging.warning('uttid=' + str(uttid) + ',target is None')
                continue

            input_len = len(input)
            target_len = len(target)
            bucket = select_bucket(input_len)
            caches[bucket][0].append(input)
            caches[bucket][1].append(target)
            caches[bucket][2] += input_len
            caches[bucket][3] += target_len
            count = count + 1
            if caches[bucket][2] > tokens_per_batch:
                feat_batch, feat_batch_mask = self._create_feat_batch(caches[bucket][0])
                target_batch, target_batch_mask = self._create_target_batch(caches[bucket][1], self.dst2idx)
                yield feat_batch, target_batch
                caches[bucket] = [[], [], 0, 0]

        for bucket in buckets:
            if len(caches[bucket][0]) > 0:
                logging.info('get_test_batches_with_buckets_and_target, len(caches[bucket][0])=' + str(len(caches[bucket][0])))
                feat_batch, feat_batch_mask = self._create_feat_batch(caches[bucket][0])
                target_batch, target_batch_mask = self._create_target_batch(caches[bucket][1], self.dst2idx)
                yield feat_batch, target_batch

        logging.info('get_test_batches_with_buckets_and_target, loaded count=' + str(count))
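
The core of this example is token-budget bucketing: each utterance is routed to the smallest bucket larger than its length, and a batch is emitted once a bucket has accumulated more frames than tokens_per_batch. Below is a minimal standalone sketch of that idea; the name make_bucketed_batches and the toy data are invented for illustration and are not part of the repo.

import numpy as np

def make_bucketed_batches(sequences, tokens_per_batch, bucket_step=5, bucket_max=10000):
    """Group variable-length sequences into length buckets and emit a batch
    once a bucket's accumulated frame count exceeds the budget."""
    buckets = list(range(50, bucket_max, bucket_step))

    def select_bucket(length):
        for b in buckets:
            if length < b:
                return b
        raise ValueError("The sequence is too long: ({})".format(length))

    caches = {b: [[], 0] for b in buckets}  # per bucket: [sequences, frame count]
    for seq in sequences:
        b = select_bucket(len(seq))
        caches[b][0].append(seq)
        caches[b][1] += len(seq)
        if caches[b][1] > tokens_per_batch:
            yield caches[b][0]
            caches[b] = [[], 0]

    # Flush whatever is left so no sequence is dropped.
    for b in buckets:
        if caches[b][0]:
            yield caches[b][0]

if __name__ == '__main__':
    rng = np.random.default_rng(0)
    toy = [np.zeros(n) for n in rng.integers(60, 400, size=32)]
    for batch in make_bucketed_batches(toy, tokens_per_batch=2000):
        print(len(batch), sum(len(s) for s in batch))
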
Example #2
    def get_test_batches_with_buckets(self, src_path, tokens_per_batch):
        buckets = list(range(50, 10000, 10))

        def select_bucket(sl):
            for l1 in buckets:
                if sl < l1:
                    return l1
            raise Exception("The sequence is too long: ({})".format(sl))

        caches = {}
        for bucket in buckets:
            caches[bucket] = [[], [], 0]  # feats, uttids, accumulated frame count

        scp_reader = zark.ArkReader(src_path)
        count = 0
        while True:
            uttid, input, loop = scp_reader.read_next_utt()
            if loop:
                break

            input_len = len(input)
            bucket = select_bucket(input_len)
            caches[bucket][0].append(input)
            caches[bucket][1].append(uttid)
            caches[bucket][2] += input_len
            count = count + 1
            if caches[bucket][2] > tokens_per_batch:
                feat_batch, feat_batch_mask = self._create_feat_batch(
                    caches[bucket][0])
                yield feat_batch, caches[bucket][1]
                caches[bucket] = [[], [], 0]

        # Flush the remaining sentences.
        for bucket in buckets:
            if len(caches[bucket][0]) > 0:
                logging.info(
                    'get_test_batches_with_buckets, len(caches[bucket][0])=' +
                    str(len(caches[bucket][0])))
                feat_batch, feat_batch_mask = self._create_feat_batch(
                    caches[bucket][0])
                yield feat_batch, caches[bucket][1]

        logging.info('get_test_batches_with_buckets, loaded count=' +
                     str(count))
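
Both test-batch generators delegate padding to self._create_feat_batch, whose implementation is not shown in these excerpts. A plausible minimal sketch of such a helper is given below, assuming zero padding to the longest utterance and a mask that is 1.0 on valid frames; the name create_feat_batch_sketch and the mask convention are assumptions, not the repo's code.

import numpy as np

def create_feat_batch_sketch(feats):
    """Hypothetical stand-in for self._create_feat_batch: zero-pad a list of
    [time, dim] feature matrices to the longest length and return the padded
    batch plus a 0/1 validity mask."""
    batch_size = len(feats)
    max_len = max(f.shape[0] for f in feats)
    feat_dim = feats[0].shape[1]
    feat_batch = np.zeros((batch_size, max_len, feat_dim), dtype=np.float32)
    feat_batch_mask = np.zeros((batch_size, max_len), dtype=np.float32)
    for i, f in enumerate(feats):
        feat_batch[i, :f.shape[0], :] = f
        feat_batch_mask[i, :f.shape[0]] = 1.0
    return feat_batch, feat_batch_mask
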
Example #3
    def get_test_batches(self, src_path, batch_size):
        scp_reader = zark.ArkReader(src_path)
        cache = []
        uttids = []
        while True:
            uttid, feat, loop = scp_reader.read_next_utt()
            if loop:
                break
            cache.append(feat)
            uttids.append(uttid)
            if len(cache) >= batch_size:
                feat_batch, feat_batch_mask = self._create_feat_batch(cache)
                # yield feat_batch, feat_batch_mask, uttids
                yield feat_batch, uttids
                cache = []
                uttids = []
        # Flush the final partial batch.
        if cache:
            feat_batch, feat_batch_mask = self._create_feat_batch(cache)
            # yield feat_batch, feat_batch_mask, uttids
            yield feat_batch, uttids
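
A likely way to drive this generator during decoding is sketched below; run_test_decoding, reader, and decode_batch are hypothetical names, with reader assumed to be an instance of the class these methods belong to and decode_batch a placeholder for whatever consumes a padded feature batch.

def run_test_decoding(reader, src_path, batch_size, decode_batch):
    """Hypothetical driver loop for get_test_batches (names not from the repo)."""
    results = {}
    for feat_batch, uttids in reader.get_test_batches(src_path, batch_size):
        hyps = decode_batch(feat_batch)    # one hypothesis per utterance
        results.update(zip(uttids, hyps))  # keep hypotheses keyed by uttid
    return results
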
Example #4
    def get_training_batches_with_buckets(self, shuffle=True):
        """
        Generate batches according to bucket setting.
        """

        # buckets = [(i, i) for i in range(5, 1000000, 3)]
        buckets = [
            (i, i)
            for i in range(self._config.bucket_min, self._config.bucket_max,
                           self._config.bucket_step)
        ]

        def select_bucket(sl, dl):
            for l1, l2 in buckets:
                if sl < l1 and dl < l2:
                    return l1, l2
            raise Exception("The sequence is too long: ({}, {})".format(
                sl, dl))

        # Shuffle the training files.
        src_path = self._config.train.src_path
        dst_path = self._config.train.dst_path
        max_length = self._config.train.max_length

        if shuffle:
            logging.info('Shuffle files %s and %s.' % (src_path, dst_path))
            src_shuf_path, dst_shuf_path = self.shuffle([src_path, dst_path],
                                                        self._config.model_dir)
            logging.info('Shuffled files %s and %s.' %
                         (src_shuf_path, dst_shuf_path))
            self._tmps.add(src_shuf_path)
            self._tmps.add(dst_shuf_path)
        else:
            src_shuf_path = src_path
            dst_shuf_path = dst_path

        caches = {}
        for bucket in buckets:
            caches[bucket] = [
                [], [], 0, 0
            ]  # src sentences, dst sentences, src tokens, dst tokens

        uttid_target_map = {}
        for line in codecs.open(dst_shuf_path, 'r', 'utf-8'):
            line = line.strip()
            if not line:
                continue
            splits = re.split(r'\s+', line)
            uttid = splits[0].strip()
            target = splits[1:]
            uttid_target_map[uttid] = target
        logging.info('loaded dst_shuf_path=' + str(dst_shuf_path) + ',size=' +
                     str(len(uttid_target_map)))

        num_random_caches = 5000
        num_cache_max_length = 600
        num_cache_min_length = 100
        num_cache_target_min_length = 4
        random_caches = []
        count = 0
        scp_reader = zark.ArkReader(src_shuf_path)
        while True:
            uttid, input, looped = scp_reader.read_next_utt()
            if looped:
                break

            target = uttid_target_map.get(uttid)
            if target is None:
                logging.warning('uttid=' + str(uttid) + ',target is None')
                continue

            input_len = len(input)
            target_len = len(target)
            if input_len > max_length or target_len > max_length:
                logging.warning('uttid=' + str(uttid) + ',input_len=' +
                                str(input_len) + ',target_len=' +
                                str(target_len) + ' exceeds max_length=' +
                                str(max_length))
                continue

            count = count + 1
            if target_len == 0:
                continue

            bucket = select_bucket(input_len, target_len)
            caches[bucket][0].append(input)
            caches[bucket][1].append(target)
            caches[bucket][2] += input_len
            caches[bucket][3] += target_len

            if len(random_caches) < num_random_caches and num_cache_min_length <= input_len <= num_cache_max_length \
                    and target_len >= num_cache_target_min_length:
                random_caches.append([input, target])

            if max(caches[bucket][2],
                   caches[bucket][3]) > self._config.train.tokens_per_batch:
                feat_batch, feat_batch_mask = self._create_feat_batch(
                    caches[bucket][0])
                target_batch, target_batch_mask = self._create_target_batch(
                    caches[bucket][1], self.dst2idx)
                # yield (feat_batch, feat_batch_mask, target_batch, target_batch_mask)
                yield (feat_batch, target_batch, len(caches[bucket][0]))
                caches[bucket] = [[], [], 0, 0]

        # Flush the remaining sentences.
        for bucket in buckets:
            # Ensure each device gets at least one sample.
            if len(caches[bucket][0]) > 0:
                src_len = len(caches[bucket][0])
                if self._config.min_count_in_bucket is None:
                    default_min_count_in_bucket = 20
                    logging.info('min_count_in_bucket=' +
                                 str(self._config.min_count_in_bucket) +
                                 ',use default_min_count_in_bucket=' +
                                 str(default_min_count_in_bucket))
                    self._config.min_count_in_bucket = default_min_count_in_bucket
                left_count = self._config.min_count_in_bucket - src_len
                if left_count > 0:  # pad the bucket up to min_count_in_bucket samples
                    for idx in range(left_count):
                        # Sample only from the entries actually collected;
                        # randint's upper bound is exclusive, so no clamping is needed.
                        if not random_caches:
                            break
                        rand_idx = np.random.randint(0, len(random_caches))
                        input, target = random_caches[rand_idx]
                        caches[bucket][0].append(input)
                        caches[bucket][1].append(target)
                        caches[bucket][2] += len(input)
                        caches[bucket][3] += len(target)
                    dst_len = len(caches[bucket][0])
                    logging.info(
                        'get_training_batches_with_buckets, src_len=' +
                        str(src_len) + ',dst_len=' + str(dst_len) +
                        ',bucket=' + str(bucket) +
                        ',max(caches[bucket][2], caches[bucket][3])=' +
                        str(max(caches[bucket][2], caches[bucket][3])))
                feat_batch, feat_batch_mask = self._create_feat_batch(
                    caches[bucket][0])
                target_batch, target_batch_mask = self._create_target_batch(
                    caches[bucket][1], self.dst2idx)
                # yield (feat_batch, feat_batch_mask, target_batch, target_batch_mask)
                yield (feat_batch, target_batch, len(caches[bucket][0]))

        logging.info('loaded count=' + str(count))
        # Remove shuffled files when epoch finished.
        if shuffle:
            os.remove(src_shuf_path)
            os.remove(dst_shuf_path)
            self._tmps.remove(src_shuf_path)
            self._tmps.remove(dst_shuf_path)
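
The most involved part of this training generator is the final top-up step that pads sparsely filled buckets with examples drawn from random_caches. The sketch below isolates that step; all names are hypothetical, and unlike the real method, which mutates caches[bucket] in place, it returns a new cache list.

import numpy as np

def top_up_bucket_sketch(bucket_cache, random_caches, min_count, rng=None):
    """Simplified, self-contained version of the top-up step above: if a bucket
    holds fewer than `min_count` (src, dst) pairs, fill it with pairs drawn at
    random from `random_caches` so every device still gets a sample."""
    rng = rng or np.random.default_rng()
    srcs, dsts, src_tokens, dst_tokens = bucket_cache
    while len(srcs) < min_count and random_caches:
        src, dst = random_caches[rng.integers(0, len(random_caches))]
        srcs.append(src)
        dsts.append(dst)
        src_tokens += len(src)
        dst_tokens += len(dst)
    return [srcs, dsts, src_tokens, dst_tokens]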