예제 #1
0
def create_bundle(
    alphabet_path,
    lm_path,
    vocab_path,
    package_path,
    force_utf8,
    default_alpha,
    default_beta,
):
    """Build a scorer package from a language model and a vocabulary file.

    The language model file is copied to ``package_path`` and the scorer
    dictionary built from the vocabulary is appended to it.

    Args:
        alphabet_path: alphabet file; required unless UTF-8 mode is used.
        lm_path: language model file to load and copy into the package.
        vocab_path: UTF-8 text file; every whitespace-separated token becomes
            a dictionary word.
        package_path: path of the package to create.
        force_utf8: UTF-8 mode override; when it does not compare equal to
            None its ``.value`` is used, otherwise the mode is detected from
            the vocabulary (UTF-8 if no token is longer than one character).
        default_alpha: default alpha passed to ``Scorer.reset_params``.
        default_beta: default beta passed to ``Scorer.reset_params``.

    Raises:
        RuntimeError: if no alphabet path is given outside UTF-8 mode, or the
            alphabet fails to deserialize.

    Exits the process with status 1 when the language model cannot be loaded
    or the package cannot be written.
    """
    words = set()
    vocab_looks_char_based = True
    # Decode explicitly as UTF-8 (the words are re-encoded as UTF-8 below)
    # instead of relying on the platform's locale encoding.
    with open(vocab_path, encoding="utf-8") as fin:
        for line in fin:
            for word in line.split():
                words.add(word.encode("utf-8"))
                if len(word) > 1:
                    vocab_looks_char_based = False
    print("{} unique words read from vocabulary file.".format(len(words)))

    cbm = "Looks" if vocab_looks_char_based else "Doesn't look"
    print("{} like a character based model.".format(cbm))

    # NOTE(review): keep "!=" rather than "is not". force_utf8 is a wrapper
    # object (it has .value) whose equality presumably compares the wrapped
    # value, so an "unset" wrapper must compare equal to None.
    if force_utf8 != None:  # pylint: disable=singleton-comparison
        use_utf8 = force_utf8.value
    else:
        use_utf8 = vocab_looks_char_based
        print("Using detected UTF-8 mode: {}".format(use_utf8))

    if use_utf8:
        serialized_alphabet = UTF8Alphabet().serialize()
    else:
        if not alphabet_path:
            raise RuntimeError("No --alphabet path specified, can't continue.")
        serialized_alphabet = Alphabet(alphabet_path).serialize()

    alphabet = NativeAlphabet()
    err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
    if err != 0:
        raise RuntimeError("Error loading alphabet: {}".format(err))

    scorer = Scorer()
    scorer.set_alphabet(alphabet)
    scorer.set_utf8_mode(use_utf8)
    scorer.reset_params(default_alpha, default_beta)
    err = scorer.load_lm(lm_path)
    # The dictionary is only appended further below, so "no trie" is treated
    # as success here; any other code is a real load error.
    if err != ds_ctcdecoder.DS_ERR_SCORER_NO_TRIE:
        print('Error loading language model file: 0x{:X}.'.format(err))
        print(
            'See the error codes section in https://deepspeech.readthedocs.io for a description.'
        )
        sys.exit(1)
    scorer.fill_dictionary(list(words))
    shutil.copy(lm_path, package_path)
    # Append the dictionary to the copied LM, do not overwrite it.
    if scorer.save_dictionary(package_path, True):
        print("Package created in {}".format(package_path))
    else:
        print("Error when creating {}".format(package_path))
        sys.exit(1)
    def __init__(self,
                 corpus_name,
                 archive_url,
                 extract_dir=None,
                 output_path=None,
                 data_dir=None,
                 csv_append_mode=False,
                 filter_alphabet=None):
        """Configure paths and filters for importing one archived corpus.

        Args:
            corpus_name: corpus name, used as a sub-directory name for both
                the dataset and the CSV output.
            archive_url: URL of the archive; its last path component is the
                archive file name.
            extract_dir: directory the archive extracts into; defaults to the
                archive name without its extension(s).
            output_path: parent directory of the CSV output; defaults to
                ``BASE_OUTPUT_FOLDER_NAME`` under the current directory,
                which is created if missing.
            data_dir: directory holding the original data; when given, wav
                paths are flagged to be written as absolute paths.
            csv_append_mode: whether CSV output should be appended to rather
                than overwritten.
            filter_alphabet: optional alphabet file path; None disables the
                alphabet filter.
        """
        self.corpus_name = corpus_name
        self.archive_url = archive_url
        self.archive_filename = self.archive_url.rsplit('/', 1)[-1]
        # os.path.splitext only strips the last extension:
        # "file.name.tar.gz" -> "file.name.tar", "cnz_1.0.0.zip" -> "cnz_1.0.0".
        self.archive_name = os.path.splitext(self.archive_filename)[0]
        # Drop only a trailing ".tar"; str.replace would also delete ".tar"
        # inside the name (e.g. "my.target.tar" -> "myget").
        if self.archive_name.endswith(".tar"):
            self.archive_name = self.archive_name[:-len(".tar")]
        self.extract_dir = (extract_dir
                            if extract_dir is not None else self.archive_name)

        # Absolute dataset root: under the CWD by default, or under data_dir.
        if data_dir is None:
            self.dataset_path = os.path.abspath(self.corpus_name)
            self.origin_data_path = os.path.join(self.dataset_path, "origin")
        else:
            self.dataset_path = os.path.join(data_dir, self.corpus_name)
            self.origin_data_path = data_dir

        if output_path is None:
            # Default: a shared importers output folder under the CWD.
            importers_output_dir = os.path.abspath(BASE_OUTPUT_FOLDER_NAME)
            if not path.exists(importers_output_dir):
                print('No path "%s" - creating ...' % importers_output_dir)
                makedirs(importers_output_dir)
            self.dataset_output_path = os.path.join(importers_output_dir,
                                                    self.corpus_name)
        else:
            # Caller-provided external output directory.
            self.dataset_output_path = os.path.join(output_path,
                                                    self.corpus_name)

        self.csv_append_mode = csv_append_mode
        # Max/min duration in seconds of a single clip.
        self.filter_max_secs = MAX_SECS
        self.filter_min_secs = MIN_SECS

        # Absolute wav paths when the data lives in a caller-provided
        # directory; otherwise paths are relative.
        self.csv_wav_absolute_path = data_dir is not None

        self.ALPHABET = (Alphabet(filter_alphabet)
                         if filter_alphabet is not None else None)
예제 #3
0
def create_bundle(
    alphabet_path,
    lm_path,
    vocab_path,
    package_path,
    force_utf8,
    default_alpha,
    default_beta,
):
    """Build a scorer package from a language model and a vocabulary file.

    The language model file is copied to ``package_path`` and the scorer
    dictionary built from the vocabulary is appended to it.

    Args:
        alphabet_path: alphabet file; required unless UTF-8 mode is used.
        lm_path: language model file to load and copy into the package.
        vocab_path: UTF-8 text file; every whitespace-separated token becomes
            a dictionary word.
        package_path: path of the package to create.
        force_utf8: UTF-8 mode override; when it does not compare equal to
            None its ``.value`` is used, otherwise UTF-8 mode is used when no
            vocabulary token is longer than one character.
        default_alpha: default alpha passed to ``Scorer.reset_params``.
        default_beta: default beta passed to ``Scorer.reset_params``.

    Exits the process with status 1 when no alphabet path is given outside
    UTF-8 mode or the alphabet fails to deserialize.
    """
    words = set()
    vocab_looks_char_based = True
    # Decode explicitly as UTF-8 (the words are re-encoded as UTF-8 below)
    # instead of relying on the platform's locale encoding.
    with open(vocab_path, encoding="utf-8") as fin:
        for line in fin:
            for word in line.split():
                words.add(word.encode("utf-8"))
                if len(word) > 1:
                    vocab_looks_char_based = False
    print("{} unique words read from vocabulary file.".format(len(words)))
    print("{} like a character based model.".format(
        "Looks" if vocab_looks_char_based else "Doesn't look"))

    # NOTE(review): keep "!=" rather than "is not". force_utf8 is a wrapper
    # object (it has .value) whose equality presumably compares the wrapped
    # value.
    if force_utf8 != None:  # pylint: disable=singleton-comparison
        use_utf8 = force_utf8.value
        print("Forcing UTF-8 mode = {}".format(use_utf8))
    else:
        use_utf8 = vocab_looks_char_based

    if use_utf8:
        serialized_alphabet = UTF8Alphabet().serialize()
    else:
        if not alphabet_path:
            print("No --alphabet path specified, can't continue.")
            sys.exit(1)
        serialized_alphabet = Alphabet(alphabet_path).serialize()

    alphabet = NativeAlphabet()
    err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
    if err != 0:
        print("Error loading alphabet: {}".format(err))
        sys.exit(1)

    scorer = Scorer()
    scorer.set_alphabet(alphabet)
    scorer.set_utf8_mode(use_utf8)
    scorer.reset_params(default_alpha, default_beta)
    # NOTE(review): the results of load_lm and save_dictionary are not
    # checked in this variant, so the success message below is unconditional.
    scorer.load_lm(lm_path)
    scorer.fill_dictionary(list(words))
    shutil.copy(lm_path, package_path)
    scorer.save_dictionary(package_path, True)  # append, not overwrite
    print("Package created in {}".format(package_path))
예제 #4
0
def get_alphabet(language):
    """Return the filter alphabet configured for ``language``, memoized.

    The path is read from the ``<language>_alphabet`` attribute of CLI_ARGS;
    an empty path yields None. Every result, None included, is cached in
    ALPHABETS so the alphabet file is loaded at most once per language.
    """
    if language not in ALPHABETS:
        alphabet_file = getattr(CLI_ARGS, language + "_alphabet")
        loaded = None
        if alphabet_file:
            loaded = Alphabet(alphabet_file)
        ALPHABETS[language] = loaded
        return loaded
    return ALPHABETS[language]
예제 #5
0
 def _ending_tester(self, file, expected):
     """Check that an alphabet file round-trips the expected labels.

     Args:
         file: name of an alphabet file in the ``test_data`` directory next
             to this module.
         expected: iterable of ``(label, label_id)`` pairs. Each label must
             encode to ``[label_id]``, and ``[label_id]`` must decode back to
             the label.
     """
     alphabet = Alphabet(
         os.path.join(os.path.dirname(__file__), 'test_data', file))
     for expected_label, expected_label_id in expected:
         # Fail with a clear message instead of swallowing KeyError. Swallowing
         # it made the following assert compare a stale value left over from
         # the previous iteration (or the initial sentinel).
         try:
             label_id = alphabet.Encode(expected_label)
         except KeyError:
             self.fail('Cannot encode {!r}'.format(expected_label))
         self.assertEqual(label_id, [expected_label_id])
         try:
             label = alphabet.Decode([expected_label_id])
         except KeyError:
             self.fail('Cannot decode {!r}'.format(expected_label_id))
         self.assertEqual(label, expected_label)
예제 #6
0
def init_worker(params):
    """Initialise the module-level state used by pool worker processes.

    Sets AUDIO_DIR to ``params.audio_dir``, or ``<params.tsv_dir>/clips``
    when that is empty. Sets FILTER_OBJ to a LabelFilter built from
    ``params.normalize``, the optional filter alphabet and the label
    validator for ``params``.
    """
    global FILTER_OBJ  # pylint: disable=global-statement
    global AUDIO_DIR  # pylint: disable=global-statement
    if params.audio_dir:
        AUDIO_DIR = params.audio_dir
    else:
        AUDIO_DIR = os.path.join(params.tsv_dir, "clips")
    validator = get_validate_label(params)
    filter_alphabet = None
    if params.filter_alphabet:
        filter_alphabet = Alphabet(params.filter_alphabet)
    FILTER_OBJ = LabelFilter(params.normalize, filter_alphabet, validator)
예제 #7
0
def create_bundle(
    alphabet_path,
    lm_path,
    vocab_path,
    package_path,
    force_utf8,
    default_alpha,
    default_beta,
):
    """Build a scorer package from a language model and a vocabulary file.

    The language model file is copied to ``package_path`` and the scorer
    dictionary built from the vocabulary is appended to it.

    Args:
        alphabet_path: alphabet file (always required in this variant).
        lm_path: language model file to load and copy into the package.
        vocab_path: UTF-8 text file; every whitespace-separated token becomes
            a dictionary word.
        package_path: path of the package to create.
        force_utf8: accepted for interface compatibility; not used here.
        default_alpha: default alpha passed to ``Scorer.reset_params``.
        default_beta: default beta passed to ``Scorer.reset_params``.

    Raises:
        RuntimeError: if no alphabet path is given or the alphabet fails to
            deserialize.
    """
    words = set()
    # Decode explicitly as UTF-8 (the words are re-encoded as UTF-8 below)
    # instead of relying on the platform's locale encoding.
    with open(vocab_path, encoding="utf-8") as fin:
        for line in fin:
            for word in line.split():
                words.add(word.encode("utf-8"))

    if not alphabet_path:
        raise RuntimeError("No --alphabet path specified, can't continue.")
    serialized_alphabet = Alphabet(alphabet_path).serialize()

    alphabet = NativeAlphabet()
    err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
    if err != 0:
        raise RuntimeError("Error loading alphabet: {}".format(err))

    scorer = Scorer()
    scorer.set_alphabet(alphabet)
    scorer.reset_params(default_alpha, default_beta)
    # TODO: other variants compare this result with
    # ds_ctcdecoder.DS_ERR_SCORER_NO_TRIE and exit on any other code; that
    # check reportedly did not work here, so the result is ignored.
    scorer.load_lm(lm_path)
    scorer.fill_dictionary(list(words))
    shutil.copy(lm_path, package_path)
    # NOTE(review): an earlier note flagged this call as the problem; its
    # result is not checked, so the success message is unconditional.
    scorer.save_dictionary(package_path, True)  # append, not overwrite
    print("Package created in {}".format(package_path))
 def __init__(self, time_step, feature_size, hidden_size, output_size,
              num_rnn_layers, *, alphabet_path="chars.txt",
              scorer_path='iam_uncased.scorer', alpha=0.75, beta=1.85):
     """Build the CNN + RNN model and its decoding scorer.

     Args:
         time_step: passed to the CNN front end and stored on the model.
         feature_size: RNN input feature size.
         hidden_size: RNN hidden size.
         output_size: RNN output size.
         num_rnn_layers: number of RNN layers (``num_layers``).
         alphabet_path: alphabet file, resolved against the current working
             directory (default ``chars.txt``).
         scorer_path: scorer package path (default ``iam_uncased.scorer``).
         alpha: scorer alpha (default 0.75).
         beta: scorer beta (default 1.85).
     """
     super(IAMModel, self).__init__()
     self.cnn = CNN(time_step=time_step)
     self.rnn = RNN(feature_size=feature_size,
                    hidden_size=hidden_size,
                    output_size=output_size,
                    num_layers=num_rnn_layers)
     self.time_step = time_step
     self.alphabet = Alphabet(os.path.abspath(alphabet_path))
     self.scorer = Scorer(alphabet=self.alphabet,
                          scorer_path=scorer_path,
                          alpha=alpha,
                          beta=beta)
 def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats,
              *, alphabet_path="chars.txt", scorer_path='librispeech.scorer',
              alpha=0.75, beta=1.85):
     """Build the residual-CNN + RNN speech model and its decoding scorer.

     Args:
         n_cnn_layers: number of ResidualCNN blocks.
         n_rnn_layers: number of RNN blocks.
         rnn_dim: RNN size; the first RNN takes ``rnn_dim`` inputs, later
             ones ``rnn_dim * 2``.
         n_class: number of output classes of the final linear layer.
         n_feats: input feature count; halved to match the stride-2 first
             convolution.
         alphabet_path: alphabet file, resolved against the current working
             directory (default ``chars.txt``).
         scorer_path: scorer package path (default ``librispeech.scorer``).
         alpha: scorer alpha (default 0.75).
         beta: scorer beta (default 1.85).
     """
     super(SpeechRecognitionModel, self).__init__()
     n_feats = n_feats // 2
     self.cnn = nn.Conv2d(1, 32, 3, stride=2, padding=1)
     self.res_cnn = nn.Sequential(*[
         ResidualCNN(32, 32, kernel=3, n_feats=n_feats)
         for _ in range(n_cnn_layers)
     ])
     self.fc = nn.Linear(n_feats * 32, rnn_dim)
     self.rnn = nn.Sequential(*[
         RNN(rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
             hidden_size=rnn_dim,
             batch_first=i == 0) for i in range(n_rnn_layers)
     ])
     self.dense = nn.Sequential(nn.Linear(rnn_dim * 2, rnn_dim), nn.GELU(),
                                nn.Dropout(0.2),
                                nn.Linear(rnn_dim, n_class))
     self.alphabet = Alphabet(os.path.abspath(alphabet_path))
     self.scorer = Scorer(alphabet=self.alphabet,
                          scorer_path=scorer_path,
                          alpha=alpha,
                          beta=beta)
예제 #10
0
if __name__ == "__main__":
    PARSER = get_importers_parser(
        description="Import XML from Conference Centre for Economics, France")
    PARSER.add_argument("target_dir", help="Destination directory")
    PARSER.add_argument(
        "--filter_alphabet",
        help="Exclude samples with characters not in provided alphabet")
    PARSER.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones")

    PARAMS = PARSER.parse_args()
    validate_label = get_validate_label(PARAMS)
    if PARAMS.filter_alphabet:
        ALPHABET = Alphabet(PARAMS.filter_alphabet)
    else:
        ALPHABET = None

    def label_filter_fun(label):
        """Clean ``label``; return None when the sample must be dropped.

        With ``--normalize``, diacritics are removed (NFKD decomposition,
        then non-ASCII characters dropped). The label then goes through
        ``maybe_normalize`` and ``validate_label``. When a filter alphabet
        was given, labels it cannot encode are dropped.
        """
        if PARAMS.normalize:
            decomposed = unicodedata.normalize("NFKD", label.strip())
            label = decomposed.encode("ascii", "ignore").decode("ascii", "ignore")
        label = validate_label(maybe_normalize(label))
        if not (ALPHABET and label):
            return label
        try:
            ALPHABET.encode(label)
        except KeyError:
            return None
        return label
예제 #11
0
    parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
    parser.add_argument("base_dir", help="Directory containing all data")
    parser.add_argument(
        "--max_duration",
        type=int,
        default=10000,
        help="Maximum sample duration in milliseconds",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones",
    )
    parser.add_argument(
        "--alphabet",
        help="Exclude samples with characters not in provided alphabet file",
    )
    parser.add_argument(
        "--keep_archive",
        type=bool,
        default=True,
        help="If downloaded archives should be kept",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Parse the command line, load the optional filter alphabet, then import.
    CLI_ARGS = handle_args()
    if CLI_ARGS.alphabet:
        ALPHABET = Alphabet(CLI_ARGS.alphabet)
    else:
        ALPHABET = None
    download_and_prepare()
예제 #12
0
    )
    parser.add_argument(
        "--skiplist",
        type=str,
        default="",
        help="Directories / books to skip, comma separated",
    )
    parser.add_argument(
        "--language", required=True, type=str, help="Dataset language to use"
    )
    return parser.parse_args()


if __name__ == "__main__":
    CLI_ARGS = handle_args()
    ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
    # Use a list, not a bare filter object. filter() returns a one-shot
    # iterator, so every pass over SKIP_LIST after the first (e.g. one check
    # per directory/book) would silently see it empty.
    SKIP_LIST = list(filter(None, CLI_ARGS.skiplist.split(",")))
    validate_label = get_validate_label(CLI_ARGS)

    def label_filter(label):
        """Clean ``label``; return None when the sample must be dropped.

        With ``--normalize``, diacritics are stripped (NFKD, then ASCII only).
        The label is then validated and, when a filter alphabet is set,
        dropped if the alphabet cannot encode it.
        """
        if CLI_ARGS.normalize:
            label = (
                unicodedata.normalize("NFKD", label.strip())
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )
        label = validate_label(label)
        if ALPHABET and label and not ALPHABET.CanEncode(label):
            label = None
        return label
class ArchiveImporter:
    def __init__(self,
                 corpus_name,
                 archive_url,
                 extract_dir=None,
                 output_path=None,
                 data_dir=None,
                 csv_append_mode=False,
                 filter_alphabet=None):
        """Configure paths and filters for importing one archived corpus.

        Args:
            corpus_name: corpus name, used as a sub-directory name for both
                the dataset and the CSV output.
            archive_url: URL of the archive; its last path component is the
                archive file name.
            extract_dir: directory the archive extracts into; defaults to the
                archive name without its extension(s).
            output_path: parent directory of the CSV output; defaults to
                ``BASE_OUTPUT_FOLDER_NAME`` under the current directory,
                which is created if missing.
            data_dir: directory holding the original data; when given, CSV
                rows reference wav files by absolute path.
            csv_append_mode: whether CSV files are appended to rather than
                overwritten.
            filter_alphabet: optional alphabet file path; transcripts the
                alphabet cannot encode are rejected. None disables the filter.
        """
        self.corpus_name = corpus_name
        self.archive_url = archive_url
        self.archive_filename = self.archive_url.rsplit('/', 1)[-1]
        # os.path.splitext only strips the last extension:
        # "file.name.tar.gz" -> "file.name.tar", "cnz_1.0.0.zip" -> "cnz_1.0.0".
        self.archive_name = os.path.splitext(self.archive_filename)[0]
        # Drop only a trailing ".tar"; str.replace would also delete ".tar"
        # inside the name (e.g. "my.target.tar" -> "myget").
        if self.archive_name.endswith(".tar"):
            self.archive_name = self.archive_name[:-len(".tar")]
        self.extract_dir = (extract_dir
                            if extract_dir is not None else self.archive_name)

        # Absolute dataset root: under the CWD by default, or under data_dir.
        if data_dir is None:
            self.dataset_path = os.path.abspath(self.corpus_name)
            self.origin_data_path = os.path.join(self.dataset_path, "origin")
        else:
            self.dataset_path = os.path.join(data_dir, self.corpus_name)
            self.origin_data_path = data_dir

        if output_path is None:
            # Default: a shared importers output folder under the CWD.
            importers_output_dir = os.path.abspath(BASE_OUTPUT_FOLDER_NAME)
            if not path.exists(importers_output_dir):
                print('No path "%s" - creating ...' % importers_output_dir)
                makedirs(importers_output_dir)
            self.dataset_output_path = os.path.join(importers_output_dir,
                                                    self.corpus_name)
        else:
            # Caller-provided external output directory.
            self.dataset_output_path = os.path.join(output_path,
                                                    self.corpus_name)

        self.csv_append_mode = csv_append_mode
        # Max/min duration in seconds of a single clip.
        self.filter_max_secs = MAX_SECS
        self.filter_min_secs = MIN_SECS

        # Audio and CSV live in different trees when data_dir is given, so
        # use absolute wav paths then; otherwise paths are relative.
        self.csv_wav_absolute_path = data_dir is not None

        self.ALPHABET = (Alphabet(filter_alphabet)
                         if filter_alphabet is not None else None)

    def run(self):
        """Run the import: fetch/extract the archive, then build the CSVs."""
        self._download_and_preprocess_data()

    # Validate and normalize transcriptions. Returns a cleaned version of the
    # label, or None if it is invalid.
    def validate_label(self, label):
        """Normalize a transcript, or return None if it cannot be used.

        Labels containing a digit or any of ``( < [ ] & * {`` are rejected.
        Otherwise the label is processed as follows:

        - lower-cased;
        - hyphens and underscores become spaces;
        - the punctuation ``. , ; ? ! : "`` is removed;
        - ``î`` becomes ``i`` (noted for the mailabs dataset);
        - runs of spaces are collapsed;
        - surrounding whitespace is stripped.

        Returns:
            The cleaned label, or None if it was rejected or ends up empty.
        """
        if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
            return None

        # Lower-case first so that an upper-case "Î" is folded to "i" too.
        label = label.lower()
        label = label.translate(str.maketrans({
            "-": " ",
            "_": " ",
            ".": None,
            ",": None,
            ";": None,
            "?": None,
            "!": None,
            ":": None,
            "\"": None,
            "î": "i",
        }))
        # Collapse spaces only after punctuation removal; otherwise removing
        # e.g. the comma in "a , b" would leave a double space behind.
        label = re.sub("[ ]{2,}", " ", label)
        label = label.strip()

        return label if label else None

    # Debug variant of validate_label: only strips and lower-cases, and
    # reports labels with characters outside DEBUG_ALPHABET.
    def __validate_label(self, label):
        """Strip and lower-case ``label``, printing it if it needs more work.

        A label containing any character not in DEBUG_ALPHABET is printed
        once, to help spot characters that still need normalization.

        Returns:
            The cleaned label, or None if it is empty.
        """
        cleaned = label.strip().lower()
        if any(ch not in DEBUG_ALPHABET for ch in cleaned):
            print('CHECK char:' + cleaned)
        return cleaned or None

    def preprocess_trascript(self, transcript):
        """Clean ``transcript`` and check it against the filter alphabet.

        The text is normalized with ``validate_label``. When a filter
        alphabet is configured and cannot encode the result, a message is
        printed and None is returned.

        Returns:
            The cleaned transcript, or None if it is invalid.
        """
        cleaned = self.validate_label(transcript)
        if not (self.ALPHABET and cleaned):
            return cleaned
        if self.ALPHABET.CanEncode(cleaned):
            return cleaned
        print('Alphabet not encode: {} '.format(cleaned))
        return None

    def _download_and_preprocess_data(self):
        """Fetch and unpack the corpus if needed, then build the CSV files.

        Steps:

        1. Create the dataset output folder if missing.
        2. Download and extract the archive only when
           ``origin_data_path/extract_dir`` does not exist yet.
        3. Collect the corpus via ``get_corpus``.
        4. Convert it with ``_maybe_convert_sets``.
        """
        output_dir = self.dataset_output_path
        if not path.exists(output_dir):
            print('No path "%s" - creating ...' % output_dir)
            makedirs(output_dir)

        archive_filename = self.archive_url.rsplit('/', 1)[-1]
        origin_dir = self.origin_data_path
        if not os.path.exists(os.path.join(origin_dir, self.extract_dir)):
            downloaded = maybe_download(archive_filename, origin_dir,
                                        self.archive_url)
            self._maybe_extract(origin_dir, self.extract_dir, downloaded)

        print('Filter audio file and parse transcript...')
        self._maybe_convert_sets(self.get_corpus())

    def _maybe_extract(self, target_dir, extract_dir, archive_path):
        """Extract ``archive_path`` into ``target_dir`` unless already done.

        Extraction is skipped when ``target_dir/extract_dir`` already exists.
        ``.zip`` archives are read with ZipFile; anything else (gzip, bz2,
        lzma tarballs) is opened with ``tarfile.open``.
        """
        extracted_path = os.path.join(target_dir, extract_dir)
        if os.path.exists(extracted_path):
            print(
                f"Found directory {extracted_path} - not extracting it from archive."
            )
            return

        print(f"No directory {extracted_path} - extracting archive...")
        if archive_path.endswith('.zip'):
            with ZipFile(archive_path, "r") as zipobj:
                zipobj.extractall(target_dir)
        else:
            # The context manager closes the archive even if extraction fails.
            with tarfile.open(archive_path) as tar:
                # The archive comes from a URL. Where supported (PEP 706), use
                # the "data" filter so members cannot escape target_dir via
                # absolute paths, ".." or links.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(target_dir, filter="data")
                else:
                    tar.extractall(target_dir)

    # Every importer must override this to provide the corpus to import.
    def get_corpus(self) -> Corpus:
        """Return the corpus (audio files and transcripts) to import.

        Raises:
            NotImplementedError: always. Previously this printed a message
                and returned None, which only failed later with an obscure
                AttributeError in ``_maybe_convert_sets``.
        """
        raise NotImplementedError('must be implemented in importer')

    ## OPTIONAL: override in importers whose data carries speaker ids.
    def get_speaker_id(self, audio_file_path):
        """Return the speaker id for ``audio_file_path``.

        The base implementation has no speaker information and returns an
        empty string, which ends up in the CSV ``speaker_id`` column.
        """
        return ""

    def _maybe_convert_sets(self, corpus: Corpus):
        """Convert and filter every audio file of ``corpus``, then write CSVs.

        Each ``corpus.audios`` entry is processed by ``one_sample`` in a
        process pool, and the rows it keeps are passed to ``_write_csv``.
        Does nothing when the corpus has no audio files.
        """
        num_samples = len(corpus.audios)
        if num_samples == 0:
            return

        # Every file is processed, even when no resampling is needed, so the
        # duration and other filters are always evaluated.
        samples = [[a, corpus.make_wav_resample, corpus.utterences[a]]
                   for a in corpus.audios]
        # Per-sample counters are accumulated here (not reported yet).
        counter = get_counter()
        print(f"Converting audio files to wav {SAMPLE_RATE}hz Mono")
        pool = Pool()
        bar = progressbar.ProgressBar(max_value=num_samples,
                                      widgets=SIMPLE_BAR)
        rows = []
        try:
            for i, processed in enumerate(
                    pool.imap_unordered(self.one_sample, samples), start=1):
                counter += processed[0]
                rows += processed[1]
                bar.update(i)
            bar.update(num_samples)
        except BaseException:
            # Stop the workers instead of leaking them when a sample fails or
            # the user interrupts.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

        # Filtered rows are evaluated in _write_csv.
        self._write_csv(corpus, rows)

    def _maybe_convert_wav(self, orig_filename, wav_filename):
        """Convert ``orig_filename`` to ``wav_filename`` with SoX if needed.

        Nothing happens when ``wav_filename`` already exists. The output
        uses SAMPLE_RATE, N_CHANNELS and BITDEPTH. SoX errors are printed,
        not raised. Readable input formats depend on the local SoX build
        (e.g. MP3 needs libmad/libmp3lame, see
        http://sox.sourceforge.net/Docs/Features).
        """
        if os.path.exists(wav_filename):
            return
        tfm = sox.Transformer()
        tfm.convert(samplerate=SAMPLE_RATE,
                    n_channels=N_CHANNELS,
                    bitdepth=BITDEPTH)
        try:
            tfm.build(str(orig_filename), str(wav_filename))
        except (sox.core.SoxError, sox.core.SoxiError) as err:
            print("SoX processing error", err, orig_filename, wav_filename)

    ## Override this in an importer to filter out rows.
    def row_validation(self, filename, duration, comments):
        """Return whether a converted sample should be kept.

        The base implementation accepts every sample. Overrides receive the
        source ``filename``, its ``duration`` in seconds (-1 if it could not
        be determined) and the SoX ``comments`` (newlines replaced by "|"),
        and may return False to exclude the sample.
        """
        return True

    def one_sample(self, sample):
        """Convert, measure and validate one audio sample.

        Args:
            sample: ``[orig_filename, make_wav_resample, transcript]`` as
                built by ``_maybe_convert_sets``.

        Returns:
            ``(counter, rows)``: the per-sample counter from ``get_counter()``
            and a list with at most one row
            ``(wav_filename, file_size, label, speaker_id, duration,
            comments, original_transcript)``.
        """
        # Set to False to keep the originals, e.g. to run the importer several
        # times during local tests.
        delete_original_if_resampled = True

        orig_filename = sample[0]
        make_wav_resample = sample[1]
        original_trascription = sample[2]

        _, f = ntpath.split(orig_filename)

        if make_wav_resample and is_audio_wav_file(orig_filename):
            # A resampled wav gets its own "converted" folder; with the same
            # name next to the original, _maybe_convert_wav would see an
            # existing file and skip the conversion.
            converted_folder = os.path.join(os.path.dirname(orig_filename),
                                            'converted')
            if not os.path.exists(converted_folder):
                # Several worker processes may race to create the folder.
                try:
                    makedirs(converted_folder)
                except OSError:
                    pass

            wav_filename = os.path.join(converted_folder, f)
        else:
            # Store the wav next to the source file, just with a .wav suffix.
            wav_filename = path.splitext(orig_filename)[0] + ".wav"

        # sox.file_info.duration is used instead of the soxi command, which
        # Windows sox distributions lack (https://github.com/rabitt/pysox/pull/74).
        duration = -1
        try:
            duration = sox.file_info.duration(orig_filename)
        except Exception:  # pylint: disable=broad-except
            # Some mp3 files fail here; the duration is then estimated from
            # the wav file size below.
            print(
                'sox.file_info.duration error on file {}, retrieve duration via filesize'
                .format(orig_filename))

        comments = ""
        try:
            comments = sox.file_info.comments(orig_filename)
        except (UnicodeError, sox.core.SoxiError):
            # Fall back to the soxi CLI, decoding leniently.
            try:
                completed = subprocess.run(
                    ["soxi", "-a", orig_filename], stdout=subprocess.PIPE)
                comments = completed.stdout.decode("utf-8", "ignore")
            except (OSError, ValueError, subprocess.SubprocessError):
                pass

        if comments:
            # "\n" is the CSV line separator, so it must not appear in a field.
            comments = comments.replace('\r', '').replace('\n', '|')

        if make_wav_resample:
            self._maybe_convert_wav(orig_filename, wav_filename)

        file_size = -1
        if os.path.exists(wav_filename):
            file_size = path.getsize(wav_filename)
            if duration == -1:
                # seconds = bytes / (sample rate * channels * bits per sample / 8)
                duration = file_size / (SAMPLE_RATE * N_CHANNELS * BITDEPTH /
                                        8)

        frames = duration * SAMPLE_RATE

        is_valid = self.row_validation(orig_filename, duration, comments)

        label = self.preprocess_trascript(original_trascription)

        speaker_id = ''
        if label is not None:
            try:
                speaker_id = self.get_speaker_id(orig_filename)
            except Exception as e:  # pylint: disable=broad-except
                print('get_speaker_id error: ' + str(e))

        rows = []
        counter = get_counter()
        if file_size == -1:
            # Excluding samples that failed upon conversion
            print(f'Conversion failed {orig_filename}')
            counter["failed"] += 1
        elif label is None or label == '' or not is_valid:
            # Excluding samples that failed on label validation
            print('Exclude label ' + original_trascription)
            counter["invalid_label"] += 1
        elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
            # Excluding samples that are too short to fit the transcript
            counter["too_short"] += 1
        elif frames / SAMPLE_RATE > self.filter_max_secs:
            # Excluding very long samples to keep a reasonable batch-size
            print(
                f' Clips too long, {str(frames / SAMPLE_RATE)}  - {orig_filename}'
            )

            counter["too_long"] += 1
        else:
            # This one is good - keep it for the target CSV
            rows.append((wav_filename, file_size, label, speaker_id, duration,
                         comments, original_trascription))
            counter["imported_time"] += frames
            # Free disk space by removing the original once resampled.
            if delete_original_if_resampled and make_wav_resample:
                os.remove(orig_filename)

        counter["all"] += 1
        # duration stays at the -1 sentinel when it could not be determined;
        # do not let that subtract time from the total.
        if duration >= 0:
            counter["total_time"] += frames
        return (counter, rows)

    def one_sample_librosa(self, sample):
        """librosa-based alternative to ``one_sample`` (labels not handled).

        Loads the file as mono at SAMPLE_RATE, optionally writes it next to
        the source as ``.wav``, and applies the duration filters.

        Args:
            sample: sequence whose first two items are the source audio path
                and the ``make_wav_resample`` flag.

        Returns:
            ``(counter, rows)``. Rows hold at most one
            ``(source_filename, file_size, label, duration, comments)`` tuple.
            NOTE(review): this is not the 7-field layout ``one_sample``
            returns and ``_write_csv`` reads.
        """
        import librosa

        mp3_wav_filename = sample[0]
        make_wav_resample = sample[1]
        # Storing wav files next to the audio filename ones - just with a different suffix
        wav_filename = path.splitext(mp3_wav_filename)[0] + ".wav"

        duration = -1
        comments = ""
        audioData = None
        try:
            # libsndfile cannot read mp3 (librosa/librosa#1015). audioread
            # then needs a backend such as ffmpeg, otherwise NoBackendError
            # is raised (librosa/librosa#219).
            audioData, _ = librosa.load(mp3_wav_filename,
                                        sr=SAMPLE_RATE,
                                        mono=True)
            duration = librosa.get_duration(y=audioData, sr=SAMPLE_RATE)
        except Exception as e:  # pylint: disable=broad-except
            # Fall back to estimating the duration from the wav file size
            # below, instead of aborting the whole import.
            print(str(e))
            print(
                'error load audio data with Librosa lib, retrieve duration via filesize - {}'
                .format(mp3_wav_filename))

        # "is not None": librosa.load returns a NumPy array, and
        # "array != None" compares element-wise, which is not a valid truth
        # value.
        if make_wav_resample and audioData is not None:
            if not os.path.exists(wav_filename):
                # NOTE(review): librosa.output was removed in librosa 0.8.
                librosa.output.write_wav(wav_filename, audioData, SAMPLE_RATE)

        file_size = -1
        if os.path.exists(wav_filename):
            file_size = path.getsize(wav_filename)
            if duration == -1:
                # seconds = bytes / (sample rate * channels * bits per sample / 8)
                duration = file_size / (SAMPLE_RATE * N_CHANNELS * BITDEPTH /
                                        8)

        frames = duration * SAMPLE_RATE

        is_valid = self.row_validation(mp3_wav_filename, duration, comments)

        label = ''  # labels are not handled by this variant
        rows = []
        counter = get_counter()
        if file_size == -1:
            # Excluding samples that failed upon conversion
            print(f'Conversion failed {mp3_wav_filename}')
            counter["failed"] += 1
        elif label is None or not is_valid:
            # Excluding samples that failed on label validation
            counter["invalid_label"] += 1
        elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
            # Excluding samples that are too short to fit the transcript
            counter["too_short"] += 1
        elif frames / SAMPLE_RATE > self.filter_max_secs:
            # Excluding very long samples to keep a reasonable batch-size
            print(
                f' Clips too long, {str(frames / SAMPLE_RATE)}  - {mp3_wav_filename}'
            )

            counter["too_long"] += 1
        else:
            # This one is good - keep it for the target CSV
            rows.append(
                (mp3_wav_filename, file_size, label, duration, comments))
            counter["imported_time"] += frames
        counter["all"] += 1
        counter["total_time"] += frames
        return (counter, rows)

    @staticmethod
    def ___one_sample(sample):
        """Trim silence from an audio file in place, or delete it.

        ``sample`` is an audio file path; non-audio files are ignored. The
        file is loaded at 16 kHz and leading/trailing silence is trimmed.
        Files whose trimmed duration is outside [MIN_SECS, MAX_SECS] are
        deleted; the others are overwritten with the trimmed audio.

        This is a staticmethod because the function never took ``self``;
        called through an instance it used to fail with a TypeError.
        """
        if not is_audio_file(sample):
            return

        # librosa is optional; import it where needed, as one_sample_librosa
        # does (it is not otherwise in scope here).
        import librosa

        y, sr = librosa.load(sample, sr=16000)
        # Trim the beginning and ending silence.
        yt, _ = librosa.effects.trim(y)
        # Keyword arguments: newer librosa versions reject positional ones.
        duration = librosa.get_duration(y=yt, sr=sr)
        if duration > MAX_SECS or duration < MIN_SECS:
            os.remove(sample)
        else:
            # NOTE(review): librosa.output was removed in librosa 0.8.
            librosa.output.write_wav(sample, yt, sr)

    def _write_csv(self, corpus: Corpus, filtered_rows):
        """Write the filtered rows to train_full/train/dev/test CSV files.

        Each row is turned into a CSV record and the records are shuffled with
        a fixed seed. They are then split: ``corpus.datasets_sizes[0]`` is the
        train fraction, ``corpus.datasets_sizes[1]`` is the test fraction, and
        dev gets the remainder. ``train_full.csv`` receives every record.

        Files are tab-separated (``excel-tab`` dialect) and are written to
        ``self.dataset_output_path``. They are appended to when
        ``self.csv_append_mode`` is set and overwritten otherwise.

        Args:
            corpus: object providing ``datasets_sizes`` (train/test fractions).
            filtered_rows: iterable of sequences whose first seven items are
                ``(wav_filename, file_size, transcript, speaker_id, duration,
                comments, source_transcript)``.

        Raises:
            ValueError: if the train plus test sizes exceed the row count.
        """
        print("\n")
        print("Writing CSV files")

        csv_data = []
        for row_data in filtered_rows:
            (wav_filename, file_size, transcript_processed, speaker_id,
             duration, comments, source_transcript) = row_data[:7]

            if self.csv_wav_absolute_path:
                # Audio files and the output CSV live on different paths, so
                # store the absolute audio path.
                wav_file_path = os.path.abspath(wav_filename)
            else:
                # Store the audio path relative to the origin data directory.
                wav_file_path = os.path.relpath(wav_filename,
                                                self.origin_data_path)

            csv_data.append(
                dict(
                    wav_filename=wav_file_path,
                    wav_filesize=file_size,
                    transcript=transcript_processed,
                    speaker_id=speaker_id,
                    duration=duration,
                    comments=comments,
                    # we save the original transcript as well
                    source_transcript=source_transcript,
                ))

        # Shuffle with a dedicated fixed-seed RNG. The order is the same as
        # seeding the global generator with 76528, but the global ``random``
        # state is left untouched.
        random.Random(76528).shuffle(csv_data)

        # Size the splits from the rows actually written. The corpus may hold
        # audios that were filtered out before reaching this method.
        samples_len = len(csv_data)
        train_len = int(samples_len * corpus.datasets_sizes[0])
        test_len = int(samples_len * corpus.datasets_sizes[1])
        if samples_len < train_len + test_len:
            raise ValueError(
                'size of the test dataset must be less than {}'.format(
                    str(samples_len - train_len)))
        dev_len = samples_len - train_len - test_len

        splits = (
            ("train_full", csv_data),
            ("train", csv_data[:train_len]),
            ("dev", csv_data[train_len:train_len + dev_len]),
            ("test", csv_data[train_len + dev_len:]),
        )

        file_open_mode = 'a' if self.csv_append_mode else 'w'
        target_csv_template = os.path.join(self.dataset_output_path, "{}.csv")
        for split_name, split_rows in splits:
            with open(target_csv_template.format(split_name),
                      file_open_mode,
                      encoding="utf-8",
                      newline="") as csv_file:
                writer = csv.DictWriter(csv_file,
                                        dialect='excel-tab',
                                        fieldnames=FIELDNAMES_CSV_FULL)
                # Write the header only into an empty file. That is always the
                # case in 'w' mode; in append mode it happens only when the
                # target file did not exist or was empty.
                if csv_file.tell() == 0:
                    writer.writeheader()
                writer.writerows(split_rows)

        print(f"Wrote {len(csv_data)} entries")
예제 #14
0
# Write the symbol set in ``classes`` to chars_small.txt, one symbol per line.
# The file is read back by Alphabet(os.path.abspath("chars_small.txt")).
# A literal '#' is written as '\#'. This is presumably because the alphabet
# format treats lines starting with '#' as comments; TODO confirm against the
# Alphabet parser.
with open("chars_small.txt", "w", encoding='utf-8') as text_file:
    text_file.write('\n'.join(x if x != '#' else '\\#' for x in classes))


def softmax(matrix):
    """Return the softmax of ``matrix`` taken along its last axis.

    For the 2-D ``(time_steps, num_classes)`` input used here, each time step
    (row) is normalised independently.

    The per-row maximum is subtracted before exponentiating. This leaves the
    mathematical result unchanged but stops large logits from overflowing to
    ``inf``, which would otherwise make the row ``nan``.

    Args:
        matrix: array-like of logits; for CTC decoding, shape
            ``(time_steps, num_classes)``.

    Returns:
        float64 ndarray with the same shape as ``matrix``; every slice along
        the last axis sums to 1.
    """
    logits = np.asarray(matrix, dtype=np.float64)
    if logits.size == 0:
        # Nothing to normalise; the max-reduction would fail on an empty axis.
        return np.zeros(logits.shape)
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


def load_rnn_output(fn):
    """Load a ';'-separated RNN output dump as a 2-D float array.

    The last column is dropped. Presumably each line ends with a trailing
    ';', which ``genfromtxt`` parses as an extra empty (``nan``) column;
    TODO confirm against the producer of the dump.

    ``genfromtxt`` returns a 1-D array for a single-line file. The result is
    therefore promoted to 2-D so that a one-time-step dump still works.

    Args:
        fn: path or file-like object accepted by ``numpy.genfromtxt``.

    Returns:
        ndarray of shape ``(time_steps, num_columns - 1)``.
    """
    data = np.atleast_2d(np.genfromtxt(fn, delimiter=';'))
    return data[:, :-1]


# Decode the saved CRNN output: build the alphabet from chars_small.txt,
# turn the raw network scores into per-time-step probabilities, and run a
# CTC beam search guided by the 'iam.scorer' language model.
alphabet = Alphabet(os.path.abspath("chars_small.txt"))
crnn_output = softmax(load_rnn_output('./rnn_output.csv'))
lm_scorer = Scorer(alphabet=alphabet,
                   scorer_path='iam.scorer',
                   alpha=0.75,
                   beta=1.85)
res = ctc_beam_search_decoder(probs_seq=crnn_output,
                              alphabet=alphabet,
                              beam_size=25,
                              scorer=lm_scorer)
# predicted: the fake friend of the family has to
# actual: the fake friend of the family, like the
best_beam = res[0]
print(best_beam[1])
예제 #15
0
def init_worker(params):
    """Set up the module-level ``FILTER_OBJ`` for the current worker.

    The filter is a ``LabelFilter`` built from ``params.normalize``, an
    ``Alphabet`` loaded from ``params.filter_alphabet`` (``None`` when that
    value is falsy), and the label validator returned by
    ``get_validate_label(params)``.
    """
    global FILTER_OBJ  # pylint: disable=global-statement
    label_validator = get_validate_label(params)
    filter_alphabet = None
    if params.filter_alphabet:
        filter_alphabet = Alphabet(params.filter_alphabet)
    FILTER_OBJ = LabelFilter(params.normalize, filter_alphabet,
                             label_validator)
예제 #16
0
def _ms_to_samples(flag_name, value_ms, sample_rate):
    """Convert a millisecond duration flag into a number of audio samples.

    If ``value_ms * sample_rate`` is not a multiple of 1000, the duration does
    not span a whole number of samples. In that case an error is logged and
    the process exits with status 1.

    Args:
        flag_name: flag name without leading dashes, used in the error text.
        value_ms: the flag's value, in milliseconds.
        sample_rate: the ``--audio_sample_rate`` value (samples per second).

    Returns:
        ``sample_rate * (value_ms / 1000)``, computed with true division and
        therefore a float.
    """
    if (value_ms * sample_rate) % 1000 != 0:
        log_error(
            '--{flag} value ({ms} ms = {sec} s) multiplied by '
            '--audio_sample_rate value ({rate}) must be an integer value. '
            'Adjust your --{flag} value or resample your audio accordingly.'
            ''.format(flag=flag_name,
                      ms=value_ms,
                      sec=value_ms / 1000,
                      rate=sample_rate))
        sys.exit(1)
    return sample_rate * (value_ms / 1000)


def initialize_globals():
    """Build the global training configuration from ``FLAGS``.

    What it does:
      * Normalises several flags in place: read buffer size, default dropout
        rates, checkpoint/summary directories, and load modes.
      * Validates flag combinations, logging and exiting with status 1 on
        fatal errors.
      * Builds the TensorFlow session config, the available device list, the
        alphabet, and the geometric constants.
      * Stores the resulting ``AttrDict`` as ``ConfigSingleton._config``.
    """
    c = AttrDict()

    # Augmentations
    c.augmentations = parse_augmentations(FLAGS.augment)
    if c.augmentations and FLAGS.feature_cache and FLAGS.cache_for_epochs == 0:
        log_warn(
            'Due to current feature-cache settings the exact same sample augmentations of the first '
            'epoch will be repeated on all following epochs. This could lead to unintended over-fitting. '
            'You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs.'
        )

    # Caching
    if FLAGS.cache_for_epochs == 1:
        log_warn(
            '--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it.'
        )

    # Read-buffer
    FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer)

    # Set default dropout rates (negative means "same as --dropout_rate")
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    # Set default checkpoint dir
    if not FLAGS.checkpoint_dir:
        FLAGS.checkpoint_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.load_train not in ['last', 'best', 'init', 'auto']:
        FLAGS.load_train = 'auto'

    if FLAGS.load_evaluate not in ['last', 'best', 'auto']:
        FLAGS.load_evaluate = 'auto'

    # Set default summary dir
    if not FLAGS.summary_dir:
        FLAGS.summary_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'summaries'))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tfv1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_placement,
        inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
        intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
        gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth))

    # CPU device
    c.cpu_device = '/cpu:0'

    # Available GPU devices
    c.available_devices = get_available_gpus(c.session_config)

    # If there is no GPU available, we fall back to CPU based operation
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    if FLAGS.bytes_output_mode:
        c.alphabet = UTF8Alphabet()
    else:
        c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26  # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9  # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.GetSize() + 1  # +1 for CTC blank label

    # Size of audio window in samples (exits if not a whole number of samples)
    c.audio_window_samples = _ms_to_samples('feature_win_len',
                                            FLAGS.feature_win_len,
                                            FLAGS.audio_sample_rate)

    # Stride for feature computations in samples (same validation)
    c.audio_step_samples = _ms_to_samples('feature_win_step',
                                          FLAGS.feature_win_step,
                                          FLAGS.audio_sample_rate)

    if FLAGS.one_shot_infer and not path_exists_remote(FLAGS.one_shot_infer):
        log_error('Path specified in --one_shot_infer is not a valid file.')
        sys.exit(1)

    if FLAGS.train_cudnn and FLAGS.load_cudnn:
        log_error('Trying to use --train_cudnn, but --load_cudnn '
                  'was also specified. The --load_cudnn flag is only '
                  'needed when converting a CuDNN RNN checkpoint to '
                  'a CPU-capable graph. If your system is capable of '
                  'using CuDNN RNN, you can just specify the CuDNN RNN '
                  'checkpoint normally with --save_checkpoint_dir.')
        sys.exit(1)

    # If separate save and load flags were not specified, default to load and save
    # from the same dir.
    if not FLAGS.save_checkpoint_dir:
        FLAGS.save_checkpoint_dir = FLAGS.checkpoint_dir

    if not FLAGS.load_checkpoint_dir:
        FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir

    ConfigSingleton._config = c  # pylint: disable=protected-access
예제 #17
0
def main():
    """Build the KenLM language model and scorer package for ``LANG``.

    Steps, in order:
      1. Write the alphabet file.
      2. Download and unzip the text corpus via ``maybe_download`` /
         ``maybe_ungzip``.
      3. Prepare the text and vocabulary using ``ARGS.workers`` counting
         processes.
      4. Run KenLM ``lmplz``, ``filter`` and ``build_binary``.
      5. Package the binary LM together with the vocabulary dictionary into
         ``kenlm.scorer``.

    From the preparation step on, a step is skipped when its output file
    already exists, unless ``redo`` is set. ``redo`` becomes true once an
    earlier step ran or the matching ``ARGS.force_*`` flag is given.

    Exits with status 1 when the alphabet, the language model, or the final
    package cannot be loaded. A broken package file is removed so that the
    next run rebuilds it instead of skipping it.
    """
    alphabet_txt = os.path.join(LANG.model_dir, 'alphabet.txt')
    raw_txt_gz = os.path.join(LANG.model_dir, 'raw.txt.gz')
    unprepared_txt = os.path.join(LANG.model_dir, 'unprepared.txt')
    prepared_txt = os.path.join(LANG.model_dir, 'prepared.txt')
    vocabulary_txt = os.path.join(LANG.model_dir, 'vocabulary.txt')
    unfiltered_arpa = os.path.join(LANG.model_dir, 'unfiltered.arpa')
    filtered_arpa = os.path.join(LANG.model_dir, 'filtered.arpa')
    lm_binary = os.path.join(LANG.model_dir, 'lm.binary')
    kenlm_scorer = os.path.join(LANG.model_dir, 'kenlm.scorer')
    temp_prefix = os.path.join(LANG.model_dir, 'tmp')

    section('Writing alphabet file', empty_lines_before=1)
    with open(alphabet_txt, 'w', encoding='utf-8') as alphabet_file:
        alphabet_file.write('\n'.join(LANG.alphabet) + '\n')

    redo = ARGS.force_download

    section('Downloading text data')
    redo = maybe_download(LANG.text_url, raw_txt_gz, force=redo)

    section('Unzipping text data')
    redo = maybe_ungzip(raw_txt_gz, unprepared_txt, force=redo)

    redo = redo or ARGS.force_prepare

    section('Preparing text and building vocabulary')
    if redo or not os.path.isfile(prepared_txt) or not os.path.isfile(vocabulary_txt):
        redo = True
        announce('Preparing {} shards of "{}"...'.format(ARGS.workers, unprepared_txt))
        counters = Queue(ARGS.workers)
        source_bytes = os.path.getsize(unprepared_txt)
        aggregator_process = Process(target=aggregate_counters, args=(vocabulary_txt, source_bytes, counters))
        aggregator_process.start()
        counter_processes = [Process(target=count_words, args=(index, counters))
                             for index in range(ARGS.workers)]
        try:
            for p in counter_processes:
                p.start()
            for p in counter_processes:
                p.join()
            # All counters are done; tell the aggregator to finish.
            counters.put(STOP_TOKEN)
            aggregator_process.join()
            print('')
            partials = [get_partial_path(i) for i in range(ARGS.workers)]
            join_files(partials, prepared_txt)
            for partial in partials:
                os.unlink(partial)
        except KeyboardInterrupt:
            aggregator_process.terminate()
            for p in counter_processes:
                p.terminate()
            raise
    else:
        announce('Files "{}" and \n\t"{}" existing - not preparing'.format(prepared_txt, vocabulary_txt))

    redo = redo or ARGS.force_generate

    section('Building unfiltered language model')
    if redo or not os.path.isfile(unfiltered_arpa):
        redo = True
        lmplz_args = [
            KENLM_BIN + '/lmplz',
            '--temp_prefix', temp_prefix,
            '--memory', '80%',
            '--discount_fallback',
            '--limit_vocab_file', vocabulary_txt,
            '--text', prepared_txt,
            '--arpa', unfiltered_arpa,
            '--skip', 'symbols',
            '--order', str(LANG.order)
        ]
        if LANG.prune:
            lmplz_args.append('--prune')
            lmplz_args.extend(str(threshold) for threshold in LANG.prune)
        subprocess.check_call(lmplz_args)
    else:
        announce('File "{}" existing - not generating'.format(unfiltered_arpa))

    section('Filtering language model')
    if redo or not os.path.isfile(filtered_arpa):
        redo = True
        with open(vocabulary_txt, 'rb') as vocabulary_file:
            vocabulary_content = vocabulary_file.read()
        subprocess.run([
            KENLM_BIN + '/filter',
            'single',
            'model:' + unfiltered_arpa,
            filtered_arpa
        ], input=vocabulary_content, check=True)
    else:
        announce('File "{}" existing - not filtering'.format(filtered_arpa))

    section('Generating binary representation')
    if redo or not os.path.isfile(lm_binary):
        redo = True
        subprocess.check_call([
            KENLM_BIN + '/build_binary',
            '-a', '255',
            '-q', '8',
            '-v',
            'trie',
            filtered_arpa,
            lm_binary
        ])
    else:
        announce('File "{}" existing - not generating'.format(lm_binary))

    section('Building scorer')
    if redo or not os.path.isfile(kenlm_scorer):
        redo = True
        words = set()
        vocab_looks_char_based = True
        with open(vocabulary_txt) as vocabulary_file:
            for line in vocabulary_file:
                for word in line.split():
                    words.add(word.encode())
                    if len(word) > 1:
                        vocab_looks_char_based = False
        announce("{} unique words read from vocabulary file.".format(len(words)))
        announce(
            "{} like a character based model.".format(
                "Looks" if vocab_looks_char_based else "Doesn't look"
            )
        )
        if ARGS.alphabet_mode == 'auto':
            use_utf8 = vocab_looks_char_based
        elif ARGS.alphabet_mode == 'utf8':
            use_utf8 = True
        else:
            use_utf8 = False
        serialized_alphabet = get_serialized_utf8_alphabet() if use_utf8 else LANG.get_serialized_alphabet()
        from ds_ctcdecoder import Scorer, Alphabet, DS_ERR_SCORER_NO_TRIE
        alphabet = Alphabet()
        err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
        if err != 0:
            announce('Error loading alphabet: {}'.format(err))
            sys.exit(1)
        scorer = Scorer()
        scorer.set_alphabet(alphabet)
        scorer.set_utf8_mode(use_utf8)
        scorer.reset_params(LANG.alpha, LANG.beta)
        # A bare LM binary has no trie appended yet, so NO_TRIE is the
        # expected result here. Anything else means the LM failed to load.
        err = scorer.load_lm(lm_binary)
        if err != DS_ERR_SCORER_NO_TRIE:
            announce('Error loading language model file: 0x{:X}.'.format(err))
            sys.exit(1)
        scorer.fill_dictionary(list(words))
        shutil.copy(lm_binary, kenlm_scorer)
        # append, not overwrite
        if not scorer.save_dictionary(kenlm_scorer, True):
            # Remove the incomplete package so the next run rebuilds it
            # instead of skipping this step.
            os.remove(kenlm_scorer)
            announce('Error when creating {}'.format(kenlm_scorer))
            sys.exit(1)
        announce('Package created in {}'.format(kenlm_scorer))
        announce('Testing package...')
        scorer = Scorer()
        err = scorer.load_lm(kenlm_scorer)
        # A complete package (LM + trie) is expected to load with code 0.
        if err != 0:
            os.remove(kenlm_scorer)
            announce('Error testing package {}: 0x{:X}.'.format(kenlm_scorer, err))
            sys.exit(1)
    else:
        announce('File "{}" existing - not building'.format(kenlm_scorer))