Example No. 1
def test_classification_json(tmpdir):
    json_path = json_data(tmpdir)

    data = SpeechRecognitionData.from_json(
        "file",
        "text",
        train_file=json_path,
        num_workers=0,
        batch_size=2,
    )
    model = SpeechRecognition(backbone=TEST_BACKBONE)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, datamodule=data)
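The json_data(tmpdir) helper and TEST_BACKBONE above come from the Flash test suite and are not shown on this page. A minimal sketch of such a fixture, assuming from_json reads a JSON-lines file whose records carry the "file" and "text" fields named in the call (file names and transcripts below are placeholders):
import json
import os


def json_data(tmpdir, n_samples=5):
    """Hypothetical re-creation of the test fixture: one JSON object per line."""
    path = os.path.join(str(tmpdir), "data.json")
    with open(path, "w") as f:
        for i in range(n_samples):
            f.write(json.dumps({"file": f"sample_{i}.wav", "text": "hello world"}) + "\n")
    return path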
Example No. 2
def from_timit(
    val_split: float = 0.1,
    batch_size: int = 4,
    num_workers: int = 0,
    **input_transform_kwargs,
) -> SpeechRecognitionData:
    """Downloads and loads the TIMIT data set."""
    download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")
    return SpeechRecognitionData.from_json(
        "file",
        "text",
        train_file="data/timit/train.json",
        test_file="data/timit/test.json",
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **input_transform_kwargs,
    )
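A possible usage sketch for the helper above, assuming the Flash imports used elsewhere on this page; any extra keyword arguments are forwarded to the input transform, and the backbone name mirrors the full examples below:
import flash
from flash.audio import SpeechRecognition

datamodule = from_timit(val_split=0.1, batch_size=8, num_workers=2)
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")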
Example No. 3
import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    batch_size=4,
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
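trainer.predict returns one entry per predict dataloader batch, so the printed object is a nested list. A small sketch of flattening it, assuming each batch yields a list of transcribed strings:
for batch_predictions in predictions:
    for transcript in batch_predictions:
        print(transcript)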
Example No. 4
def test_audio_module_not_found_error():
    with pytest.raises(ModuleNotFoundError, match="[audio]"):
        SpeechRecognitionData.from_json("file", "text", train_file="", batch_size=1, num_workers=0)
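The test above relies on Flash raising ModuleNotFoundError when the audio extras are not installed. A minimal sketch of handling that case at call time (the train file path and the pip extra name are illustrative):
try:
    SpeechRecognitionData.from_json("file", "text", train_file="train.json", batch_size=1)
except ModuleNotFoundError as err:
    # Usually resolved with something like: pip install 'lightning-flash[audio]'
    print(f"Audio dependencies are not installed: {err}")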
Example No. 5
def test_from_json(tmpdir):
    json_path = json_data(tmpdir)
    dm = SpeechRecognitionData.from_json("file", "text", train_file=json_path, batch_size=1, num_workers=0)
    batch = next(iter(dm.train_dataloader()))
    assert DefaultDataKeys.INPUT in batch
    assert DefaultDataKeys.TARGET in batch
Example No. 6
import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip",
              "./data")

datamodule = SpeechRecognitionData.from_json(
    input_fields="file",
    target_fields="text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# 4. Predict on audio files!
predictions = model.predict(["data/timit/example.wav"])
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")  # checkpoint path is illustrative
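The saved checkpoint can later be restored with the standard Lightning load_from_checkpoint classmethod; a sketch, reusing the illustrative path from the save call above:
loaded_model = SpeechRecognition.load_from_checkpoint("speech_recognition_model.pt")
loaded_predictions = loaded_model.predict(["data/timit/example.wav"])
print(loaded_predictions)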