    def test_tokenizer_decode_added_tokens(self):
        tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h")
        tokenizer.add_tokens(["!", "?"])
        tokenizer.add_special_tokens({"cls_token": "$$$"})

        sample_ids = [
            [
                11,
                5,
                15,
                tokenizer.pad_token_id,
                15,
                8,
                98,
                32,
                32,
                33,
                tokenizer.word_delimiter_token_id,
                32,
                32,
                33,
                34,
                34,
            ],
            [
                24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77,
                tokenizer.pad_token_id, 34, 34
            ],
        ]
        batch_tokens = tokenizer.batch_decode(sample_ids)

        self.assertEqual(batch_tokens,
                         ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])

    def test_tokenizer_decode_special(self):
        # TODO(PVP) - change to facebook
        tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h")

        sample_ids = [
            [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
            [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
        ]
        sample_ids_2 = [
            [11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
            [
                24,
                22,
                5,
                tokenizer.pad_token_id,
                tokenizer.pad_token_id,
                tokenizer.pad_token_id,
                tokenizer.word_delimiter_token_id,
                24,
                22,
                5,
                77,
                tokenizer.word_delimiter_token_id,
            ],
        ]

        batch_tokens = tokenizer.batch_decode(sample_ids)
        batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
        self.assertEqual(batch_tokens, batch_tokens_2)
        self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
    def __init__(self, device):
        self.device = device

        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h")
        self.raw_model = Wav2Vec2Model.from_pretrained(
            "facebook/wav2vec2-base-960h").to(self.device)
Example #4
    def test_inference_ctc_robust_batched(self):
        model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

        input_speech = self._load_datasamples(4)

        inputs = tokenizer(input_speech,
                           return_tensors="pt",
                           padding=True,
                           truncation=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        with torch.no_grad():
            logits = model(input_values, attention_mask=attention_mask).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = tokenizer.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
            "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
            "his instant panic was followed by a small sharp blow high on his chest",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
Example #5
def get_predictions(test_dir_root: str, bs: int, extra_step: float, loading_step: float) -> None:

    device = torch.device("cuda:0") if torch.cuda.is_available() \
        else torch.device("cpu")

    # load model and tokenizer
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).eval().to(device)
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(MODEL_NAME)

    test_dir_root = Path(test_dir_root)

    # iterate over the files in the correct order
    with open(test_dir_root / "FILE_ORDER", "r") as f:
        wav_file_order = f.read().splitlines()

    token_predictions = {}
    for wf in wav_file_order:
        wf = f"{wf}.wav"
        print(f"Generating token predictions for {wf}")
        path_to_wav = test_dir_root / "wavs" / wf
        token_predictions[wf] = get_preds_for_wav(model, tokenizer, device, bs,
            path_to_wav, extra_step, loading_step)

    test_dir_root.mkdir(parents=True, exist_ok=True)
    path_to_preds = test_dir_root / "token_predictions.json"
    with open(path_to_preds, "w") as f:
        json.dump(token_predictions, f)

    print(f"Wav2Vec predictions saved at {path_to_preds}")
Example #6
def load_asr_model(device):
    """Load model"""
    print(f"[INFO]: Load the pre-trained ASR by {ASR_PRETRAINED_MODEL}.")
    model = Wav2Vec2ForCTC.from_pretrained(ASR_PRETRAINED_MODEL).to(device)
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(ASR_PRETRAINED_MODEL)
    models = {"model": model, "tokenizer": tokenizer}
    return models
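A short usage sketch for the dictionary returned above; the 16 kHz waveform speech, the device, and ASR_PRETRAINED_MODEL being a CTC checkpoint are assumptions, not from the original.

models = load_asr_model(device)
# speech: 1-D numpy array sampled at 16 kHz
input_values = models["tokenizer"](speech, return_tensors="pt").input_values.to(device)
with torch.no_grad():
    logits = models["model"](input_values).logits
# greedy CTC decode of the argmax ids back to text
transcription = models["tokenizer"].batch_decode(torch.argmax(logits, dim=-1))[0]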
Example #7
    def test_inference_ctc_normal_batched(self):
        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
        model.to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h", do_lower_case=True)

        input_speech = self._load_datasamples(2)

        inputs = tokenizer(input_speech,
                           return_tensors="pt",
                           padding=True,
                           truncation=True)

        input_values = inputs.input_values.to(torch_device)

        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = tokenizer.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def __init__(self, device):
        super(COVIDWav2Vec, self).__init__()
        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h"
        )
        self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.linear = nn.Linear(768, 1)
        self.device = device
        return
Example #9
def get_embeddings(filename):
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(
        "facebook/wav2vec2-base-960h")
    # load audio
    audio_input, _ = sf.read(filename)

    # transcribe
    input_values = tokenizer(audio_input, return_tensors="pt").input_values
    print(input_values)
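The example is cut off after the print call. To actually return embeddings rather than just tokenized input, the usual continuation feeds the input values through a Wav2Vec2Model and keeps its hidden states; everything below is a sketch under that assumption, not recovered from the original.

    # the model expects 16 kHz audio; sf.read returns the file's native rate
    model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
    with torch.no_grad():
        # frame-level embeddings, shape (batch, time, hidden_size)
        embeddings = model(input_values).last_hidden_state
    return embeddings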
    def test_pretrained_checkpoints_are_set_correctly(self):
        # this test makes sure that models that are using
        # group norm don't have their tokenizer return the
        # attention_mask
        for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
            config = Wav2Vec2Config.from_pretrained(model_id)
            tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_id)

            # only "layer" feature extraction norm should make use of
            # attention_mask
            self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")
def process(PATH):
    audio, sampling_rate = librosa.load(PATH, sr=16000)
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(
        "facebook/wav2vec2-base-960h")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    input_values = tokenizer(audio, return_tensors='pt').input_values
    logits = model(input_values).logits
    prediction = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(prediction)[0]
    print(transcription)
    return
Example #12
class Transcription():
    """
    Simple class to upload the data in the sound file and transcribe it.
    """
    tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

    #initialize file names
    origin_file = 'audio.wav'
    destination_file = 'rec4.wav'

    file_name = 'rec4.wav'
    file_path = os.path.join('.', file_name)

#    def __init__(self, origin_file):
#        self.origin_file = origin_file


    def change_filename(self):
        "Change the audio file from .oga to .wav"

        if os.path.exists(self.destination_file):
            os.remove(self.destination_file)

        process = subprocess.run(['ffmpeg', '-hide_banner','-i', self.origin_file, self.destination_file])
        if process.returncode != 0:
            raise RuntimeError(f"ffmpeg failed to convert {self.origin_file} to {self.destination_file}")


    def map_to_array(self):
        "Read file and convert to a format that the model can accept"

        self.speech, self.sampling_rate = torchaudio.load(self.origin_file)
        self.resample_rate = 16000
        self.speech = librosa.resample(
            np.asarray(self.speech).reshape(-1),
            orig_sr=self.sampling_rate,
            target_sr=self.resample_rate)
        self.speech = librosa.to_mono(self.speech)
        return self.speech, self.resample_rate


    def indicate_transcription(self):
        "Transcribe"

        #self.change_filename()
        self.speech, self.sampling_rate = self.map_to_array()
        input_values = self.tokenizer(self.speech, return_tensors="pt", padding="longest").input_values
        logits = self.model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.tokenizer.batch_decode(predicted_ids)
        transcription = ''.join(transcription)
        return transcription.lower()

    def __str__(self):
        return self.indicate_transcription()
Example #13
File: transcribe.py Project: lmmx/tap
def transcribe_audio_file(
    audio_file,
    #model_to_load="facebook/wav2vec2-base-960h",
    model_to_load="facebook/wav2vec2-large-960h-lv60-self",
):
    # load model and tokenizer
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_to_load)
    model = Wav2Vec2ForCTC.from_pretrained(model_to_load)
    audio_input, _ = librosa.load(audio_file, sr=16000)
    tok = tokenizer(audio_input, return_tensors="pt")
    input_values = tok.input_values
    logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)[0]
    return transcription
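A brief, hypothetical usage example for the function above (the file name is illustrative only):

if __name__ == "__main__":
    # librosa.load inside transcribe_audio_file resamples the file to 16 kHz
    print(transcribe_audio_file("sample.wav"))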
    def test_inference_masked_lm_normal(self):
        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
        model.to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)

        input_speech = self._load_datasamples(1)

        input_values = tokenizer(input_speech, return_tensors="pt").input_values.to(torch_device)

        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = tokenizer.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
Example #15
def load_model(root, device):
    """Load model"""
    pretrain_models = {
        "EN": "facebook/wav2vec2-large-960h-lv60-self",
        "DE": "jonatasgrosman/wav2vec2-large-xlsr-53-german",
        "FR": "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        "IT": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
        "ES": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
    }

    print(f"[INFO]: Load the pre-trained ASR by {pretrain_models[root]}.")
    model = Wav2Vec2ForCTC.from_pretrained(pretrain_models[root]).to(device)
    if root.upper() == "EN":
        tokenizer = Wav2Vec2Tokenizer.from_pretrained(pretrain_models[root])
    elif root.upper() in ["DE", "FR", "IT", "ES"]:
        tokenizer = Wav2Vec2Processor.from_pretrained(pretrain_models[root])
    else:
        print(f"{root} not available.")
        exit()

    models = {"model": model, "tokenizer": tokenizer}

    return models
Example #16
def asr_inference(tsv_path: Path, batch_size: int) -> None:

    # number of cpu and gpu devices
    n_gpu = torch.cuda.device_count()
    n_cpu = cpu_count()
    print(f"Number of cuda devices: {n_gpu} | Number of CPU cores: {n_cpu}")

    # specify main device and all devices (if gpu available)
    device_list = [torch.device(f"cuda:{i}") for i in range(n_gpu)]
    main_device = device_list[0] if n_gpu > 0 else torch.device("cpu")
    print(f"Main device: {main_device}")
    print(f"Parallel devices = {device_list}")

    # load model and tokenizer
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(MODEL_NAME)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).eval()
    print(f"Loaded model and tokenizer: {MODEL_NAME}")

    if len(device_list) > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=device_list,
                                      output_device=main_device)

    model.to(main_device)

    # to store results
    wer_path = tsv_path.parent / f"{tsv_path.stem}_wer_results.json"

    # check if there are already predictions for some ids
    if wer_path.is_file():
        with open(wer_path, "r") as file:
            completed_ids = [json.loads(line)["id"] for line in file]
    else:
        completed_ids = []

    # load dataset and initialize dataloader
    dataset = AsrDataset(tsv_path, completed_ids)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            collate_fn=asr_collate_fn,
                            shuffle=False,
                            num_workers=n_cpu // 8,
                            drop_last=False)
    print("Loaded dataset and intialized dataloader")

    # for global scores
    start_time = time()
    wer_sum, n_examples = 0, 0

    print(
        f"Starting inference, results file: {wer_path} is updated every batch")

    # loop through the dataset
    with torch.no_grad():
        for (ids, audio, tgt_text) in tqdm(iter(dataloader),
                                           miniters=len(dataloader) // 100):

            tokenized_audio = tokenizer(audio,
                                        return_tensors="pt",
                                        padding="longest")
            input_values = tokenized_audio.input_values.to(main_device)
            attention_mask = tokenized_audio.attention_mask.to(main_device)

            # retrieve logits
            logits = model(input_values, attention_mask=attention_mask).logits

            # take argmax and decode
            predicted_ids = torch.argmax(logits, dim=-1)
            transcriptions = [
                trans.lower()
                for trans in tokenizer.batch_decode(predicted_ids)
            ]

            # record WER per datapoint
            wer_results = [
                wer(tgt_text[i], transcriptions[i]) for i in range(len(ids))
            ]

            # store results for this batch
            with open(wer_path, "a") as f:
                for i in range(len(ids)):
                    json.dump(
                        {
                            "id": ids[i],
                            "WER": wer_results[i],
                            "target text": tgt_text[i],
                            "transcription": transcriptions[i]
                        }, f)
                    f.write("\n")

            # for global scores
            wer_sum += sum(wer_results)
            n_examples += len(wer_results)

    if wer_sum != 0:
        n_data = len(dataset)
        total_time = time() - start_time
        avg_proc_time = total_time / n_data
        macro_wer = wer_sum / n_examples

        print(
            f"Finished inference for {n_data} datapoints in {total_time / 60} minutes"
        )
        print(
            f"Average processing time per datapoint = {avg_proc_time} seconds")
        print(f"Macro-averaged WER = {macro_wer}")

    else:
        print("Predictions file was already completed.")
Example #17
    def __init__(self):
        super().__init__()
        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h")
        self.model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-base-960h")

    def __init__(self):
        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h")
Example #19
    parser.add_argument(
        "--max_segm_len",
        type=str,
        required=True,
        help="Maximum segment length during segmentation in seconds.\
                Use two number separated by commas to run for each value in this range."
    )
    parser.add_argument("--min_pause_len",
                        type=float,
                        default=0.2,
                        help="Minimum allowed pause length in seconds")
    args = parser.parse_args()

    dataset_root = Path(args.dataset_root)

    tokenizer = Wav2Vec2Tokenizer.from_pretrained(MODEL_NAME)

    # find original segmentation file
    original_segmentation_file = next(dataset_root.glob("*en-de.yaml"))

    # run multiple times
    if "," in args.max_segm_len:
        (dataset_root / "own_segmentation").mkdir(parents=True, exist_ok=True)

        max_segm_len_range = range(*map(int, args.max_segm_len.split(",")))
        for max_segm_len_value in max_segm_len_range:
            segment_dataset(max_segm_len_value, args.min_pause_len,
                            dataset_root, tokenizer,
                            original_segmentation_file)
    # run once
    else:
Example #20
def segment_dataset(max_segm_len_secs: int, min_pause_len: float,
                    dataset_root: Path, tokenizer: Wav2Vec2Tokenizer,
                    path_to_original_segmentation: Path) -> None:

    # convert secs to wav2vec prediction steps
    max_segm_len_steps = int(max_segm_len_secs / MS)
    min_pause_len = int(min_pause_len / MS)

    # load wav2vec token predictions for all the wavs in the dataset
    with open(dataset_root / "token_predictions.json", "r") as f:
        predictions = json.load(f)

    segm_data = []

    print(
        f"Segmenting with max segmentation length = {max_segm_len_secs} secs.")

    for wav_file, preds in predictions.items():

        tokens = tokenizer.convert_ids_to_tokens(preds)

        # convert to a long string and replace the pad token with a space (easier to handle)
        tokens_text = "".join(tokens).replace("<pad>", " ")

        # apply segmentation algorithm
        segments = split_text_to_segments(tokens_text, max_segm_len_steps,
                                          min_pause_len)

        # duration and offset of each segment in seconds
        duration = np.array([len(segm) for segm in segments]) * MS
        duration_cumsum = np.insert(np.cumsum(duration)[1:], 0, 0)
        offset = duration_cumsum - duration
        offset[0] = 0

        # delete segments that do not contain words
        num_pauses = 0
        for i, segm in reversed(list(enumerate(segments))):
            if is_pause(segm):
                del segments[i]
                offset = np.delete(offset, i)
                duration = np.delete(duration, i)
                num_pauses += 1

        # expand each segment by EXTRA_MS steps before and after
        offset = list(offset - EXTRA_MS * MS)
        duration = list(duration + EXTRA_MS * MS)

        path_to_wav = str(dataset_root / wav_file)

        segm_data.extend([{
            "wav": path_to_wav,
            "offset": round(float(off), 2),
            "duration": round(float(dur), 2)
        } for off, dur in zip(offset, duration)])

        print(
            f"{wav_file} -> {len(segments) + num_pauses} segments ({num_pauses} pauses) "
        )

    new_segmentation_file = Path(f"_own_{max_segm_len_secs}.".join(
        str(path_to_original_segmentation).rsplit(".", maxsplit=1))).name
    path_to_new_segmentation = dataset_root / "own_segmentation" / new_segmentation_file
    with open(path_to_new_segmentation, "w") as f:
        yaml.dump(segm_data, f, default_flow_style=True)

    print(
        f"Segmentation for max_segm_len={max_segm_len_secs} saved at {path_to_new_segmentation}"
    )
Example #21
# HF: https://huggingface.co/facebook/wav2vec2-base-960h?s=09

from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
import torch
import librosa
import os
import json
from segmenter import Segmenter

# load model and tokenizer
#tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", cache_dir=os.getenv("cache_dir", "../models"))
#model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", cache_dir=os.getenv("cache_dir", "../models"))

# load model and tokenizer
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h",
                                              cache_dir=os.getenv(
                                                  "cache_dir", "../../models"))
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h",
                                       cache_dir=os.getenv(
                                           "cache_dir", "../../models"))

audio_ds = [
    os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data',
                 'sample.mp3'),
    os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data',
                 'long_sample.mp3')
]

# create speech segmenter
seg = Segmenter(model_path=os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'speech_segmenter_models'),
Example #22
# #+++++++++
# print('+++++++++load tokenizer+++++++++')
# tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
# print('+++++++++load model+++++++++')
# model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# print('+++++++++save tokenizer+++++++++')
# tokenizer.save_pretrained('/tokenizer/')
# print('+++++++++save model+++++++++')
# model.save_pretrained('/model/')

#load any audio file of your choice

#load model and tokenizer
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# with open("test_transcripts.txt", 'wt') as out:
#     out.write("/data/voshpde/asr_huggingface/transcriptions.txt")


def get_text(filename: str):
    try:
        speech, rate = librosa.load(filename, sr=16000)

        input_values = tokenizer(speech, return_tensors='pt').input_values
        #Store logits (non-normalized predictions)
        logits = model(input_values).logits

        #Store predicted ids
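The snippet stops at the comment above. A plausible continuation, including a closing except for the open try block (the original's exact error handling is unknown, so treat this as a sketch), would be:

        predicted_ids = torch.argmax(logits, dim=-1)
        # greedy CTC decode back to a transcription string
        return tokenizer.batch_decode(predicted_ids)[0]
    except Exception as e:
        print(f"Could not transcribe {filename}: {e}")
        return ""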

    def get_tokenizer(self, **kwargs):
        kwargs.update(self.special_tokens_map)
        return Wav2Vec2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
Example #24
import torch
import os
import librosa
from flask import Flask, render_template, request, jsonify, redirect

app = Flask(__name__)

## Download the pretrained model
#tokenizer = Wav2Vec2Tokenizer.from_pretrained('facebook/wav2vec2-base-960h')
#model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')
#tokenizer.save_pretrained('./model')
#model.save_pretrained('./model')

cur_dir = os.getcwd()
dir_path = os.path.join(cur_dir, "models")
tokenizer = Wav2Vec2Tokenizer.from_pretrained(dir_path)
model = Wav2Vec2ForCTC.from_pretrained(dir_path)


def transcribe(file):
    audio, rate = librosa.load(file, sr=16000)
    input_values = tokenizer(audio, padding='longest',
                             return_tensors='pt').input_values
    logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)
    return transcription[0]


@app.route("/", methods=["GET", "POST"])
def index():
Example #25
# @author Loreto Parisi (loretoparisi at gmail dot com)
# Copyright (c) 2020-2021 Loreto Parisi (loretoparisi at gmail dot com)
# Code adapted from https://pastebin.com/3wWj59uz
# Code adapted from [Nikita Schneider](https://twitter.com/DeepSchneider) https://twitter.com/DeepSchneider/status/1381179738824314880?s=20

import os
import librosa
import torch
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
import deepspeed

#tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h", cache_dir=os.getenv("cache_dir", "../models"))
#model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h", cache_dir=os.getenv("cache_dir", "../models"))

tokenizer = Wav2Vec2Tokenizer.from_pretrained(
    "facebook/wav2vec2-large-960h-lv60-self",
    cache_dir=os.getenv("cache_dir", "../models"))
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-960h-lv60-self",
    cache_dir=os.getenv("cache_dir", "../models"))

# DeepSpeed config
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '9994'
os.environ['RANK'] = "0"
os.environ['LOCAL_RANK'] = "0"
os.environ['WORLD_SIZE'] = "1"
config = {
    "train_batch_size": 8,
    "fp16": {
        "enabled": True,
Example #26
def load_model():
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(
        "facebook/wav2vec2-base-960h")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    return tokenizer, model
Example #27
    def __init__(self, path, label):
        self.path = path
        self.label = label
        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h")