def test_stage_test_and_valid(tmpdir):
    csv_path = csv_data(tmpdir)
    dm = SpeechRecognitionData.from_csv(
        "file",
        "text",
        train_file=csv_path,
        val_file=csv_path,
        test_file=csv_path,
        batch_size=1,
        num_workers=0,
    )
    batch = next(iter(dm.val_dataloader()))
    assert DefaultDataKeys.INPUT in batch
    assert DefaultDataKeys.TARGET in batch

    batch = next(iter(dm.test_dataloader()))
    assert DefaultDataKeys.INPUT in batch
    assert DefaultDataKeys.TARGET in batch
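# Note: `csv_data(tmpdir)` is a helper defined elsewhere in this test module. A minimal sketch of
# what such a helper might look like (the silent 16 kHz WAV, filenames, and sample transcript are
# assumptions for illustration, not the repository's actual fixture):
import os
import wave


def csv_data(tmpdir):
    # Write a short silent mono WAV so the dataloader has a real audio file to decode.
    audio_path = os.path.join(tmpdir, "sample.wav")
    with wave.open(audio_path, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(16000)
        wav.writeframes(b"\x00\x00" * 16000)

    # Point the "file" and "text" columns expected by SpeechRecognitionData.from_csv at it.
    csv_path = os.path.join(tmpdir, "data.csv")
    with open(csv_path, "w") as f:
        f.write("file,text\n")
        f.write(f"{audio_path},hello world\n")
    return csv_path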
def test_classification_json(tmpdir):
    json_path = json_data(tmpdir)
    data = SpeechRecognitionData.from_json(
        "file",
        "text",
        train_file=json_path,
        num_workers=0,
        batch_size=2,
    )
    model = SpeechRecognition(backbone=TEST_BACKBONE)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, datamodule=data)
def from_timit(
    val_split: float = 0.1,
    batch_size: int = 4,
    num_workers: int = 0,
    **input_transform_kwargs,
) -> SpeechRecognitionData:
    """Downloads and loads the TIMIT data set."""
    download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")
    return SpeechRecognitionData.from_json(
        "file",
        "text",
        train_file="data/timit/train.json",
        test_file="data/timit/test.json",
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **input_transform_kwargs,
    )
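# A short usage sketch for the `from_timit` helper above (assumes the imports below; the backbone
# name matches the one used in the finetuning example in this repository, and `fast_dev_run` keeps
# the run cheap):
import flash
from flash.audio import SpeechRecognition

if __name__ == "__main__":
    datamodule = from_timit(val_split=0.1, batch_size=4, num_workers=0)
    model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
    trainer = flash.Trainer(fast_dev_run=True)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")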
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    batch_size=4,
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
def test_audio_module_not_found_error():
    with pytest.raises(ModuleNotFoundError, match="[audio]"):
        SpeechRecognitionData.from_json("file", "text", train_file="", batch_size=1, num_workers=0)
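# Note: the error matched above is Flash's hint that the optional audio dependencies are missing;
# they are typically installed with the "audio" extra, e.g. `pip install 'lightning-flash[audio]'`.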
def test_from_json(tmpdir):
    json_path = json_data(tmpdir)
    dm = SpeechRecognitionData.from_json("file", "text", train_file=json_path, batch_size=1, num_workers=0)
    batch = next(iter(dm.train_dataloader()))
    assert DefaultDataKeys.INPUT in batch
    assert DefaultDataKeys.TARGET in batch
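# As with `csv_data`, `json_data(tmpdir)` is a helper defined elsewhere in the test module. A
# minimal sketch under the same assumptions (one small WAV plus a JSON-lines file with "file" and
# "text" keys, which is the layout `from_json` expects):
import json
import os


def json_data(tmpdir):
    audio_path = os.path.join(tmpdir, "sample.wav")
    # ... write a small WAV here, e.g. as in the `csv_data` sketch above ...
    json_path = os.path.join(tmpdir, "data.json")
    with open(json_path, "w") as f:
        f.write(json.dumps({"file": audio_path, "text": "hello world"}) + "\n")
    return json_path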
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    input_fields="file",
    target_fields="text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# 4. Predict on audio files!
predictions = model.predict(["data/timit/example.wav"])
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")
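# A small follow-up sketch: reloading the checkpoint saved in step 5 for inference. The checkpoint
# path is the one assumed above, and `load_from_checkpoint` is the standard LightningModule API.
from flash.audio import SpeechRecognition

model = SpeechRecognition.load_from_checkpoint("speech_recognition_model.pt")
predictions = model.predict(["data/timit/example.wav"])
print(predictions)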