Example #1
File: align.py  Project: tnellesen/DSAlign
def main():
    # Debug helpers
    logging.basicConfig()
    logging.root.setLevel(args.loglevel if args.loglevel else logging.INFO)

    def progress(it=None, desc='Processing', total=None):
        logging.info(desc)
        return it if args.no_progress else log_progress(
            it, interval=args.progress_interval, total=total)

    def resolve(base_path, spec_path):
        if spec_path is None:
            return None
        if not path.isabs(spec_path):
            spec_path = path.join(base_path, spec_path)
        return spec_path

    def exists(file_path):
        if file_path is None:
            return False
        return os.path.isfile(file_path)

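    # Queue of (audio, tlog, script, aligned) tuples collected below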
    to_prepare = []

    def enqueue_or_fail(audio, tlog, script, aligned, prefix=''):
        if exists(aligned) and not args.force:
            fail(
                prefix +
                'Alignment file "{}" already exists - use --force to overwrite'
                .format(aligned))
        if tlog is None:
            if args.ignore_missing:
                return
            fail(prefix + 'Missing transcription log path')
        if not exists(audio) and not exists(tlog):
            if args.ignore_missing:
                return
            fail(prefix +
                 'Both audio file "{}" and transcription log "{}" are missing'.
                 format(audio, tlog))
        if not exists(script):
            if args.ignore_missing:
                return
            fail(prefix + 'Missing script "{}"'.format(script))
        to_prepare.append((audio, tlog, script, aligned))

    if (args.audio or
            args.tlog) and args.script and args.aligned and not args.catalog:
        enqueue_or_fail(args.audio, args.tlog, args.script, args.aligned)
    elif args.catalog:
        if not exists(args.catalog):
            fail('Unable to load catalog file "{}"'.format(args.catalog))
        catalog = path.abspath(args.catalog)
        catalog_dir = path.dirname(catalog)
        with open(catalog, 'r', encoding='utf-8') as catalog_file:
            catalog_entries = json.load(catalog_file)
        for entry in progress(catalog_entries, desc='Reading catalog'):
            enqueue_or_fail(
                resolve(catalog_dir, entry['audio']),
                resolve(catalog_dir, entry['tlog']),
                resolve(catalog_dir, entry['script']),
                resolve(catalog_dir, entry['aligned']),
                prefix='Problem loading catalog "{}" - '.format(catalog))
    else:
        fail(
            'You have to specify either a combination of "--audio/--tlog,--script,--aligned" or "--catalog"'
        )

    logging.debug('Start')

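    # Phase 1: for each queued item that lacks a transcription log, run VAD
    # and STT on the audio and write the tlog JSON.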
    to_align = []
    output_graph_path = None
    for audio_path, tlog_path, script_path, aligned_path in to_prepare:
        if not exists(tlog_path):
            generated_scorer = False
            if output_graph_path is None:
                logging.debug(
                    'Looking for model files in "{}"...'.format(model_dir))
                output_graph_path = glob(model_dir + "/*.pbmm")[0]
                lang_scorer_path = glob(model_dir + "/*.scorer")[0]
            kenlm_path = 'dependencies/kenlm/build/bin'
            if not path.exists(kenlm_path):
                kenlm_path = None
            deepspeech_path = 'dependencies/deepspeech'
            if not path.exists(deepspeech_path):
                deepspeech_path = None
            if kenlm_path and deepspeech_path and not args.stt_no_own_lm:
                tc = read_script(script_path)
                if not tc.clean_text.strip():
                    logging.error('Cleaned transcript is empty for {}'.format(
                        path.basename(script_path)))
                    continue
                clean_text_path = script_path + '.clean'
                with open(clean_text_path, 'w',
                          encoding='utf-8') as clean_text_file:
                    clean_text_file.write(tc.clean_text)

                scorer_path = script_path + '.scorer'
                if not path.exists(scorer_path):
                    # Generate LM
                    data_lower, vocab_str = convert_and_filter_topk(
                        scorer_path, clean_text_path, 500000)
                    build_lm(scorer_path, kenlm_path, 5, '85%', '0|0|1', True,
                             255, 8, 'trie', data_lower, vocab_str)
                    os.remove(scorer_path + '.' + 'lower.txt.gz')
                    os.remove(scorer_path + '.' + 'lm.arpa')
                    os.remove(scorer_path + '.' + 'lm_filtered.arpa')
                    os.remove(clean_text_path)

                    # Generate scorer
                    create_bundle(alphabet_path,
                                  scorer_path + '.' + 'lm.binary',
                                  scorer_path + '.' + 'vocab-500000.txt',
                                  scorer_path, False, 0.931289039105002,
                                  1.1834137581510284)
                    os.remove(scorer_path + '.' + 'lm.binary')
                    os.remove(scorer_path + '.' + 'vocab-500000.txt')

                    generated_scorer = True
            else:
                scorer_path = lang_scorer_path

            logging.debug(
                'Loading acoustic model from "{}", alphabet from "{}" and scorer from "{}"...'
                .format(output_graph_path, alphabet_path, scorer_path))

            # Run VAD on the input file
            logging.debug('Transcribing VAD segments...')
            frames = read_frames_from_file(audio_path, model_format,
                                           args.audio_vad_frame_length)
            segments = vad_split(frames,
                                 model_format,
                                 num_padding_frames=args.audio_vad_padding,
                                 threshold=args.audio_vad_threshold,
                                 aggressiveness=args.audio_vad_aggressiveness)

            def pre_filter():
                for i, segment in enumerate(segments):
                    segment_buffer, time_start, time_end = segment
                    time_length = time_end - time_start
                    if args.stt_min_duration and time_length < args.stt_min_duration:
                        logging.info(
                            'Fragment {}: Audio too short for STT'.format(i))
                        continue
                    if args.stt_max_duration and time_length > args.stt_max_duration:
                        logging.info(
                            'Fragment {}: Audio too long for STT'.format(i))
                        continue
                    yield (time_start, time_end,
                           np.frombuffer(segment_buffer, dtype=np.int16))

            samples = list(progress(pre_filter(), desc='VAD splitting'))

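            # Transcribe segments in parallel; each worker is initialized
            # once with the model and scorer paths via init_stt.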
            pool = multiprocessing.Pool(initializer=init_stt,
                                        initargs=(output_graph_path,
                                                  scorer_path),
                                        processes=args.stt_workers)
            transcripts = list(
                progress(pool.imap(stt, samples),
                         desc='Transcribing',
                         total=len(samples)))

            fragments = []
            for time_start, time_end, segment_transcript in transcripts:
                if segment_transcript is None:
                    continue
                fragments.append({
                    'start': time_start,
                    'end': time_end,
                    'transcript': segment_transcript
                })
            logging.debug('Excluded {} empty transcripts'.format(
                len(transcripts) - len(fragments)))

            logging.debug(
                'Writing transcription log to file "{}"...'.format(tlog_path))
            with open(tlog_path, 'w', encoding='utf-8') as tlog_file:
                tlog_file.write(
                    json.dumps(fragments,
                               indent=4 if args.output_pretty else None,
                               ensure_ascii=False))

            # Remove scorer if generated
            if generated_scorer:
                os.remove(scorer_path)
        if not path.isfile(tlog_path):
            fail('Problem loading transcript from "{}"'.format(tlog_path))
        to_align.append((tlog_path, script_path, aligned_path))

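    # Phase 2: align each transcription log against its script in parallel,
    # aggregating fragment counts and drop reasons.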
    total_fragments = 0
    dropped_fragments = 0
    reasons = Counter()

    index = 0
    pool = multiprocessing.Pool(processes=args.align_workers)
    for aligned_file, file_total_fragments, file_dropped_fragments, file_reasons in \
            progress(pool.imap_unordered(align, to_align), desc='Aligning', total=len(to_align)):
        if args.no_progress:
            index += 1
            logging.info(
                'Aligned file {} of {} - wrote results to "{}"'.format(
                    index, len(to_align), aligned_file))
        total_fragments += file_total_fragments
        dropped_fragments += file_dropped_fragments
        reasons += file_reasons

    logging.info('Aligned {} fragments'.format(total_fragments))
    if total_fragments > 0 and dropped_fragments > 0:
        logging.info('Dropped {} fragments ({:0.2f}%):'.format(
            dropped_fragments, dropped_fragments * 100.0 / total_fragments))
        for key, number in reasons.most_common():
            logging.info(' - {}: {}'.format(key, number))
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-n',
        '--name',
        nargs=1,
        type=str,
        help='text to process; if "all", process all texts',
        required=True)
    parser.add_argument('-m',
                        '--mode',
                        nargs=1,
                        type=str,
                        help='decoding mode: "greedy" or "lm+fst"',
                        required=True)
    args = parser.parse_args()
    sample_name = args.name[0]
    ctc_mode = args.mode[0]
    if sample_name != 'all':
        filenames = [sample_name]
    else:
        filenames = [
            name.name for name in Path('narecni_govor_z_besedili/').iterdir()
        ]
    print(filenames)
    for filename in filenames:
        data_dir = os.path.join('narecni_govor_z_besedili', filename)

        # Pipeline steps: 1 = NeMo ASR + segmentation, 2 = text prep and
        # KenLM build, 3 = FST construction, 4 = CTC decoding, 5 = DSAlign
        # alignment. Edit this set to rerun selected steps.
        _steps = {2, 3, 4, 5}
        print(filename)
        if 1 in _steps:
            # get nemo output
            audio_filename = os.path.join(data_dir, filename + '.wav')
            log_prob, log_prob_step_sec, vocab, segments = segment_and_asr(
                audio_filename)
            _k = log_prob_step_sec * 10**3
            frame_segments = [[int(e[0] / _k), int(e[1] / _k)]
                              for e in segments]
            json.dump(frame_segments,
                      open(os.path.join(data_dir, 'frame_segments.json'), 'w'),
                      indent=4)
            np.save(os.path.join(data_dir, filename + '.npy'), log_prob)

        if 2 in _steps:
            with open(os.path.join(data_dir, filename + '.txt'), 'r') as f:
                full_text = f.read()
            char_based_text = to_char_based_text(full_text)
            with open(os.path.join(data_dir, filename + '_char.txt'),
                      'w') as f:
                f.write(char_based_text)
            # generate lm (note: this reassigns `args`, discarding the parsed
            # CLI namespace; sample_name/ctc_mode were extracted above, and
            # step 3 reads args.top_k from this new namespace)
            args = argparse.Namespace()
            args.input_txt = os.path.join(data_dir, filename + '.txt')
            args.output_dir = data_dir
            args.top_k = 5000
            args.kenlm_bins = os.path.join('kenlm-master', 'build', 'bin')
            args.arpa_order = 5
            args.max_arpa_memory = "80%"
            args.binary_a_bits = 255
            args.binary_q_bits = 8
            args.binary_type = 'trie'
            args.discount_fallback = True

            data_lower, vocab_str = convert_and_filter_topk(args)
            with gzip.open(data_lower, 'rb') as f:
                origin_text = f.read().decode('utf8')
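            # Build fragments of 1-5 consecutive words (via get_slices) as
            # alignment candidates; start/end are dummy zeros since only the
            # transcript text is used.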
            originjson = [{
                'transcript': word,
                'start': 0,
                'end': 0,
                'index': i
            } for i, word in enumerate(origin_text.split())]
            origin_doublewords = [{
                'transcript': ' '.join(words),
                'start': 0,
                'end': 0,
                'index': i
            } for i, words in enumerate(get_slices(origin_text.split(), 2))]
            origin_triplewords = [{
                'transcript': ' '.join(words),
                'start': 0,
                'end': 0,
                'index': i
            } for i, words in enumerate(get_slices(origin_text.split(), 3))]
            origin_quadruplewords = [{
                'transcript': ' '.join(words),
                'start': 0,
                'end': 0,
                'index': i
            } for i, words in enumerate(get_slices(origin_text.split(), 4))]
            origin_quintuplewords = [{
                'transcript': ' '.join(words),
                'start': 0,
                'end': 0,
                'index': i
            } for i, words in enumerate(get_slices(origin_text.split(), 5))]
            json.dump(originjson,
                      open(os.path.join(data_dir, filename + '.json'), 'w'),
                      indent=4,
                      ensure_ascii=False)
            json.dump(origin_doublewords,
                      open(os.path.join(data_dir, filename + '_double.json'),
                           'w'),
                      indent=4,
                      ensure_ascii=False)
            json.dump(origin_triplewords,
                      open(os.path.join(data_dir, filename + '_triple.json'),
                           'w'),
                      indent=4,
                      ensure_ascii=False)
            json.dump(origin_quadruplewords,
                      open(
                          os.path.join(data_dir, filename + '_quadruple.json'),
                          'w'),
                      indent=4,
                      ensure_ascii=False)
            json.dump(origin_quintuplewords,
                      open(
                          os.path.join(data_dir, filename + '_quintuple.json'),
                          'w'),
                      indent=4,
                      ensure_ascii=False)

            build_lm(args, data_lower, vocab_str)

        if 3 in _steps:
            # create fst
            out_file = os.path.join(data_dir, filename + '.fst')
            fst_vocab = os.path.join(data_dir, f'vocab-{args.top_k}.txt')
            token_file = 'labels.txt'
            build_fst_from_words_file(fst_vocab, token_file, out_file)

        if 4 in _steps:
            # decode
            # print(frame_segments)
            frame_segments = json.load(
                open(os.path.join(data_dir, 'frame_segments.json'), 'r'))
            char_ms = real_ctc_decode(filename,
                                      data_dir,
                                      frame_segments,
                                      mode=ctc_mode)
            np.savetxt(os.path.join(data_dir, 'char_ms.txt'), char_ms)

        if 5 in _steps:
            char_ms = np.loadtxt(os.path.join(data_dir, 'char_ms.txt'))
            copyfile(f'{data_dir}/{filename}_quintuple.json',
                     f'{data_dir}/origin_fragments.json')
            dropped_fragments = 1
            while dropped_fragments > 0:
                subprocess.check_call([
                    'DSAlign/bin/align.sh', '--script',
                    f'{data_dir}/{filename}_result.txt', '--tlog',
                    f'{data_dir}/origin_fragments.json', '--aligned',
                    f'{data_dir}/{filename}_aligned.json',
                    '--align-candidate-threshold', '1', '--alphabet',
                    'alphabet.txt', '--force', '--output-pretty'
                ])
                origin_fragments = json.load(
                    open(f'{data_dir}/origin_fragments.json', 'r'))
                n_fragments = len(origin_fragments)
                aligned_fragments = json.load(
                    open(f'{data_dir}/{filename}_aligned.json', 'r'))
                matched_ids = set([e['index'] for e in aligned_fragments])
                non_matched_ids = set([
                    idx for idx in range(n_fragments) if idx not in matched_ids
                ])
                dropped_fragments = n_fragments - len(aligned_fragments)
                if dropped_fragments > 0:
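                    # Merge each unmatched fragment into a neighbour and
                    # retry until no fragments are dropped.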
                    merges = []
                    for i in range(n_fragments):
                        if i not in non_matched_ids:
                            continue
                        if i == n_fragments - 1:
                            target = i - 1
                        else:
                            target = i + 1
                        merges.append((min(i, target), max(i, target)))
                        non_matched_ids.discard(i)
                        non_matched_ids.discard(target)
                    origin_unmerged = dict([(e['index'], e['transcript'])
                                            for e in origin_fragments])
                    for i, i_next in merges:
                        text_i_next = origin_unmerged[i_next]
                        text_i = origin_unmerged.pop(i)
                        origin_unmerged[i_next] = (text_i.rstrip() + ' ' +
                                                   text_i_next.lstrip())
                    origin_merged = [{
                        'transcript': val,
                        'start': 0,
                        'end': 0,
                        'index': i
                    } for i, (k, val) in enumerate(origin_unmerged.items())]
                    json.dump(origin_merged,
                              open(f'{data_dir}/origin_fragments.json', 'w'),
                              ensure_ascii=False,
                              indent=4)

            aligned = json.load(
                open(os.path.join(data_dir, filename + '_aligned.json'), 'r'))
            aligned_true = []
            word_index = 0
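            # Estimate per-word times by mapping each word's character
            # offsets in the matched span onto the char_ms timeline.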
            for el in aligned:
                words = el['transcript']
                l0 = len(words)
                l1 = len(el['aligned'])
                start = el['text-start']
                starts = [start]
                ends = []
                for x in re.finditer(' ', words):
                    _start = start + int(x.span()[0] / l0 * l1)
                    _end = start + int((x.span()[0] - 1) / l0 * l1)
                    starts.append(_start)
                    ends.append(_end)
                ends.append(start + l1 - 1)
                for i, word in enumerate(words.split()):
                    ms_start = char_ms[
                        starts[i] if starts[i] < len(char_ms) else -1]
                    ms_end = char_ms[
                        ends[i] if ends[i] < len(char_ms) else -1]
                    if ms_start == ms_end:
                        ms_end += 0.01
                    aligned_true.append({
                        'text': word,
                        'start': float(ms_start),
                        'end': float(ms_end),
                        'index': word_index
                    })
                    word_index += 1
            json.dump(aligned_true,
                      open(os.path.join(data_dir, f'aligned_{ctc_mode}.json'),
                           'w'),
                      indent=4,
                      ensure_ascii=False)
Example #3
def main():
    # Debug helpers
    logging.basicConfig()
    logging.root.setLevel(args.loglevel if args.loglevel else logging.INFO)

    def progress(it=None, desc="Processing", total=None):
        logging.info(desc)
        return (it if args.no_progress else log_progress(
            it, interval=args.progress_interval, total=total))

    def resolve(base_path, spec_path):
        if spec_path is None:
            return None
        if not path.isabs(spec_path):
            spec_path = path.join(base_path, spec_path)
        return spec_path

    def exists(file_path):
        if file_path is None:
            return False
        return os.path.isfile(file_path)

    to_prepare = []

    def enqueue_or_fail(audio, tlog, script, aligned, prefix=""):
        if exists(aligned) and not args.force:
            fail(
                prefix +
                'Alignment file "{}" already exists - use --force to overwrite'
                .format(aligned))
        if tlog is None:
            if args.ignore_missing:
                return
            fail(prefix + "Missing transcription log path")
        if not exists(audio) and not exists(tlog):
            if args.ignore_missing:
                return
            fail(prefix +
                 'Both audio file "{}" and transcription log "{}" are missing'.
                 format(audio, tlog))
        if not exists(script):
            if args.ignore_missing:
                return
            fail(prefix + 'Missing script "{}"'.format(script))
        to_prepare.append((audio, tlog, script, aligned))

    if (args.audio or
            args.tlog) and args.script and args.aligned and not args.catalog:
        enqueue_or_fail(args.audio, args.tlog, args.script, args.aligned)
    elif args.catalog:
        if not exists(args.catalog):
            fail('Unable to load catalog file "{}"'.format(args.catalog))
        catalog = path.abspath(args.catalog)
        catalog_dir = path.dirname(catalog)
        with open(catalog, "r", encoding="utf-8") as catalog_file:
            catalog_entries = json.load(catalog_file)
        for entry in progress(catalog_entries, desc="Reading catalog"):
            enqueue_or_fail(
                resolve(catalog_dir, entry["audio"]),
                resolve(catalog_dir, entry["tlog"]),
                resolve(catalog_dir, entry["script"]),
                resolve(catalog_dir, entry["aligned"]),
                prefix='Problem loading catalog "{}" - '.format(catalog),
            )
    else:
        fail(
            'You have to specify either a combination of "--audio/--tlog,--script,--aligned" or "--catalog"'
        )

    logging.debug("Start")

    to_align = []
    output_graph_path = None
    for audio_path, tlog_path, script_path, aligned_path in to_prepare:
        if not exists(tlog_path):
            generated_scorer = False
            if output_graph_path is None:
                logging.debug(
                    'Looking for model files in "{}"...'.format(model_dir))
                output_graph_path = glob(model_dir + "/*.pbmm")[0]
                lang_scorer_path = glob(model_dir + "/*.scorer")[0]
            kenlm_path = "/install/kenlm/build/bin"
            deepspeech_path = "third_party/DeepSpeech"
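            # With --per-document-lm, build a KenLM scorer from the script's
            # own text; otherwise use the generic scorer from model_dir.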
            if args.per_document_lm:
                assert path.exists(kenlm_path)
                assert path.exists(deepspeech_path)

                # The cleaned transcript must exist before the scorer is
                # built: convert_and_filter_topk reads clean_text_path.
                tc = read_script(script_path)
                if not tc.clean_text.strip():
                    logging.error("Cleaned transcript is empty for {}".format(
                        path.basename(script_path)))
                    continue
                clean_text_path = script_path + ".clean"
                with open(clean_text_path, "w",
                          encoding="utf-8") as clean_text_file:
                    clean_text_file.write(tc.clean_text)

                scorer_path = script_path + ".scorer"
                if not path.exists(scorer_path):
                    data_lower, vocab_str = convert_and_filter_topk(
                        scorer_path, clean_text_path, 500000)
                    build_lm(
                        scorer_path,
                        kenlm_path,
                        5,
                        "85%",
                        "0|0|1",
                        True,
                        255,
                        8,
                        "trie",
                        data_lower,
                        vocab_str,
                    )
                    os.remove(scorer_path + "." + "lower.txt.gz")
                    os.remove(scorer_path + "." + "lm.arpa")
                    os.remove(scorer_path + "." + "lm_filtered.arpa")
                    os.remove(clean_text_path)

                    create_bundle(
                        alphabet_path,
                        scorer_path + "." + "lm.binary",
                        scorer_path + "." + "vocab-500000.txt",
                        scorer_path,
                        False,
                        0.931289039105002,
                        1.1834137581510284,
                    )
                    os.remove(scorer_path + "." + "lm.binary")
                    os.remove(scorer_path + "." + "vocab-500000.txt")

                generated_scorer = True
            else:
                scorer_path = lang_scorer_path

            logging.debug(
                'Loading acoustic model from "{}", alphabet from "{}" and scorer from "{}"...'
                .format(output_graph_path, alphabet_path, scorer_path))

            # Run VAD on the input file
            logging.debug("Transcribing VAD segments...")
            frames = read_frames_from_file(audio_path, model_format,
                                           args.audio_vad_frame_length)
            frames = list(frames)
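            # Debug aid: materialize the frame generator and dump the voiced
            # buffers for offline inspection.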
            with open("dsalign_voiced_buffers.npy", "wb") as fh:
                np.save(fh, frames)
            segments = vad_split(
                frames,
                model_format,
                num_padding_frames=args.audio_vad_padding,
                threshold=args.audio_vad_threshold,
                aggressiveness=args.audio_vad_aggressiveness,
            )

            def pre_filter():
                for i, segment in enumerate(segments):
                    segment_buffer, time_start, time_end = segment
                    time_length = time_end - time_start
                    if args.stt_min_duration and time_length < args.stt_min_duration:
                        logging.info(
                            "Fragment {}: Audio too short for STT".format(i))
                        continue
                    if args.stt_max_duration and time_length > args.stt_max_duration:
                        logging.info(
                            "Fragment {}: Audio too long for STT".format(i))
                        continue
                    yield (
                        time_start,
                        time_end,
                        np.frombuffer(segment_buffer, dtype=np.int16),
                    )

            samples = list(progress(pre_filter(), desc="VAD splitting"))

            # Multiprocessing is applied only to the chunks within a single
            # document, which limits parallelism (and would leave a TPU or
            # GPU underutilized). TODO: replace this pool with a proper
            # work queue.
            pool = multiprocessing.Pool(
                initializer=init_stt,
                initargs=(output_graph_path, scorer_path),
                processes=args.stt_workers,
            )
            transcripts = list(
                progress(pool.imap(stt, samples),
                         desc="Transcribing",
                         total=len(samples)))

            fragments = []
            for time_start, time_end, segment_transcript in transcripts:
                if segment_transcript is None:
                    continue
                fragments.append({
                    "start": time_start,
                    "end": time_end,
                    "transcript": segment_transcript,
                })
            logging.debug("Excluded {} empty transcripts".format(
                len(transcripts) - len(fragments)))

            logging.debug(
                'Writing transcription log to file "{}"...'.format(tlog_path))
            with open(tlog_path, "w", encoding="utf-8") as tlog_file:
                tlog_file.write(
                    json.dumps(
                        fragments,
                        indent=4 if args.output_pretty else None,
                        ensure_ascii=False,
                    ))

            # Remove scorer if generated
            if generated_scorer:
                os.remove(scorer_path)
        if not path.isfile(tlog_path):
            fail('Problem loading transcript from "{}"'.format(tlog_path))
        to_align.append((tlog_path, script_path, aligned_path))

    total_fragments = 0
    dropped_fragments = 0
    reasons = Counter()

    index = 0
    pool = multiprocessing.Pool(processes=args.align_workers)
    for (
            aligned_file,
            file_total_fragments,
            file_dropped_fragments,
            file_reasons,
    ) in progress(pool.imap_unordered(align, to_align),
                  desc="Aligning",
                  total=len(to_align)):
        if args.no_progress:
            index += 1
            logging.info(
                'Aligned file {} of {} - wrote results to "{}"'.format(
                    index, len(to_align), aligned_file))
        total_fragments += file_total_fragments
        dropped_fragments += file_dropped_fragments
        reasons += file_reasons

    logging.info("Aligned {} fragments".format(total_fragments))
    if total_fragments > 0 and dropped_fragments > 0:
        logging.info("Dropped {} fragments {:0.2f}%:".format(
            dropped_fragments, dropped_fragments * 100.0 / total_fragments))
        for key, number in reasons.most_common():
            logging.info(" - {}: {}".format(key, number))