Example #1
import unittest

import torchaudio


def set_audio_backend(backend):
    """Allow additional backend value, 'default'"""
    backends = torchaudio.list_audio_backends()
    if backend == 'default':
        if 'sox_io' in backends:
            be = 'sox_io'
        elif 'soundfile' in backends:
            be = 'soundfile'
        else:
            raise unittest.SkipTest('No default backend available')
    else:
        be = backend

    torchaudio.set_audio_backend(be)
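
As a rough usage sketch (not part of the original example): a test case could call this helper from setUp so every test runs against a concrete backend. The class name below and the use of the 'default' value are illustrative assumptions; set_audio_backend refers to the helper defined above.

import unittest

import torchaudio


class BackendSmokeTest(unittest.TestCase):
    def setUp(self):
        # 'default' resolves to sox_io, then soundfile, or skips the test
        # when neither backend is available.
        set_audio_backend('default')

    def test_backend_is_registered(self):
        # The active backend should be one of the backends torchaudio reports.
        self.assertIn(torchaudio.get_audio_backend(),
                      torchaudio.list_audio_backends())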
Example #2
import unittest

import torchaudio


def set_audio_backend(backend):
    """Allow additional backend value, 'default'"""
    backends = torchaudio.list_audio_backends()
    if backend == 'soundfile-new':
        be = 'soundfile'
        torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
    elif backend == 'default':
        if 'sox_io' in backends:
            be = 'sox_io'
        elif 'soundfile' in backends:
            be = 'soundfile'
            torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = True
        else:
            raise unittest.SkipTest('No default backend available')
    else:
        be = backend

    torchaudio.set_audio_backend(be)
Example #3
import os
import tempfile
import unittest
from typing import Type, Iterable, Union
from contextlib import contextmanager
from shutil import copytree

import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio.list_audio_backends()


def get_asset_path(*paths):
    """Return full path of a test asset"""
    return os.path.join(_TEST_DIR_PATH, 'assets', *paths)


def create_temp_assets_dir():
    """
    Creates a temporary directory and moves all files from test/assets there.
    Returns a Tuple[string, TemporaryDirectory] which is the folder path
    and object.
    """
    tmp_dir = tempfile.TemporaryDirectory()
    copytree(os.path.join(_TEST_DIR_PATH, "assets"),
             os.path.join(tmp_dir.name, "assets"))
    return tmp_dir.name, tmp_dir
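
A minimal usage sketch (the asset file name 'sinewave.wav' is only an illustration; the helpers above are assumed to be in scope):

# Resolve an asset path and work on a temporary copy of the assets directory,
# cleaning the copy up explicitly when done.
path = get_asset_path('sinewave.wav')           # hypothetical asset name
tmp_path, tmp_dir = create_temp_assets_dir()
try:
    copied = os.path.join(tmp_path, 'assets', 'sinewave.wav')
    print(os.path.exists(path), os.path.exists(copied))
finally:
    tmp_dir.cleanup()                           # remove the temporary copy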
Example #4
    dtype = None
    device = None
    backend = None

    def setUp(self):
        super().setUp()
        set_audio_backend(self.backend)


class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
    pass


def skipIfNoExec(cmd):
    return unittest.skipIf(
        shutil.which(cmd) is None, f'`{cmd}` is not available')


def skipIfNoModule(module, display_name=None):
    display_name = display_name or module
    return unittest.skipIf(not is_module_available(module),
                           f'"{display_name}" is not available')


skipIfNoSoxBackend = unittest.skipIf(
    'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(),
                               reason='CUDA not available')
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio',
                                   'torchaudio C++ extension')
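
A short sketch of how the mixin and skip decorators above might be combined (the test class and method are hypothetical):

# Hypothetical test case: `backend` is picked up by TestBaseMixin.setUp, and the
# decorators skip the test when sox or the C++ extension is not available.
class SoxIoBackendTest(TorchaudioTestCase):
    backend = 'sox_io'

    @skipIfNoExec('sox')
    @skipIfNoExtension
    def test_backend_active(self):
        assert torchaudio.get_audio_backend() == 'sox_io'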
Example #5
def train(args):
    args = _parse_args(args)
    _LG.info("%s", args)

    args.save_dir.mkdir(parents=True, exist_ok=True)
    if "sox_io" in torchaudio.list_audio_backends():
        torchaudio.set_audio_backend("sox_io")

    start_epoch = 1
    if args.resume:
        checkpoint = torch.load(args.resume)
        if args.sample_rate != checkpoint["sample_rate"]:
            raise ValueError(
                f"The provided sample rate ({args.sample_rate}) does not match "
                f"the sample rate from the checkpoint ({checkpoint['sample_rate']})."
            )
        if args.num_speakers != checkpoint["num_speakers"]:
            raise ValueError(
                f"The provided number of speakers ({args.num_speakers}) does not match "
                f"the number of speakers from the checkpoint ({checkpoint['num_speakers']})."
            )
        start_epoch = checkpoint["epoch"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _LG.info("Using: %s", device)

    model = _get_model(num_sources=args.num_speakers)
    model.to(device)

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[device] if torch.cuda.is_available() else None
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.resume:
        _LG.info("Loading parameters from the checkpoint...")
        model.module.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        dist_utils.synchronize_params(
            str(args.save_dir / "tmp.pt"), device, model, optimizer
        )

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )

    train_loader, valid_loader, eval_loader = _get_dataloader(
        args.dataset,
        args.dataset_dir,
        args.num_speakers,
        args.sample_rate,
        args.batch_size,
    )

    num_train_samples = len(train_loader.dataset)
    num_valid_samples = len(valid_loader.dataset)
    num_eval_samples = len(eval_loader.dataset)

    _LG.info_on_master("Datasets:")
    _LG.info_on_master(" - Train: %s", num_train_samples)
    _LG.info_on_master(" - Valid: %s", num_valid_samples)
    _LG.info_on_master(" -  Eval: %s", num_eval_samples)

    trainer = conv_tasnet.trainer.Trainer(
        model,
        optimizer,
        train_loader,
        valid_loader,
        eval_loader,
        args.grad_clip,
        device,
        debug=args.debug,
    )

    log_path = args.save_dir / "log.csv"
    _write_header(log_path, args)
    dist_utils.write_csv_on_master(
        log_path,
        [
            "epoch",
            "learning_rate",
            "valid_si_snri",
            "valid_sdri",
            "eval_si_snri",
            "eval_sdri",
        ],
    )

    _LG.info_on_master("Running %s epochs", args.epochs)
    for epoch in range(start_epoch, start_epoch + args.epochs):
        _LG.info_on_master("=" * 70)
        _LG.info_on_master("Epoch: %s", epoch)
        _LG.info_on_master("Learning rate: %s", optimizer.param_groups[0]["lr"])
        _LG.info_on_master("=" * 70)

        t0 = time.monotonic()
        trainer.train_one_epoch()
        train_sps = num_train_samples / (time.monotonic() - t0)

        _LG.info_on_master("-" * 70)

        t0 = time.monotonic()
        valid_metric = trainer.validate()
        valid_sps = num_valid_samples / (time.monotonic() - t0)
        _LG.info_on_master("Valid: %s", valid_metric)

        _LG.info_on_master("-" * 70)

        t0 = time.monotonic()
        eval_metric = trainer.evaluate()
        eval_sps = num_eval_samples / (time.monotonic() - t0)
        _LG.info_on_master(" Eval: %s", eval_metric)

        _LG.info_on_master("-" * 70)

        _LG.info_on_master("Train: Speed: %6.2f [samples/sec]", train_sps)
        _LG.info_on_master("Valid: Speed: %6.2f [samples/sec]", valid_sps)
        _LG.info_on_master(" Eval: Speed: %6.2f [samples/sec]", eval_sps)

        _LG.info_on_master("-" * 70)

        dist_utils.write_csv_on_master(
            log_path,
            [
                epoch,
                optimizer.param_groups[0]["lr"],
                valid_metric.si_snri,
                valid_metric.sdri,
                eval_metric.si_snri,
                eval_metric.sdri,
            ],
        )

        lr_scheduler.step(valid_metric.si_snri)

        save_path = args.save_dir / f"epoch_{epoch}.pt"
        dist_utils.save_on_master(
            save_path,
            {
                "model": model.module.state_dict(),
                "optimizer": optimizer.state_dict(),
                "num_speakers": args.num_speakers,
                "sample_rate": args.sample_rate,
                "epoch": epoch,
            },
        )
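
For reference, a minimal sketch of reading back one of the checkpoints written above (the file name epoch_10.pt is only an illustration):

import torch

# Checkpoints saved by the loop above contain the keys
# "model", "optimizer", "num_speakers", "sample_rate" and "epoch".
checkpoint = torch.load("epoch_10.pt", map_location="cpu")
print(checkpoint["epoch"], checkpoint["num_speakers"], checkpoint["sample_rate"])
# model.module.load_state_dict(checkpoint["model"]) restores the weights of a
# model built with _get_model(num_sources=checkpoint["num_speakers"]).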
Example #6
import logging
import os
import sys
from typing import Iterable, List, Tuple

import numpy as np
import torch
from torch import nn
import torchaudio
from pathos.threading import ThreadPool
from torchaudio.transforms import MFCC, Resample
import torchlibrosa

logger = logging.getLogger()

# Use sox_io backend if available
if (
    torchaudio.get_audio_backend() != "sox_io"
    and "sox_io" in torchaudio.list_audio_backends()
):
    torchaudio.set_audio_backend("sox_io")
    logger.debug("Set audio backend to sox_io")

# Required because, as of 0.7.2 on macOS, torchaudio links its own OpenMP runtime in addition to PyTorch's.
# This tells OpenMP not to abort when the duplicate runtime is loaded.
if sys.platform == "darwin":
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


class AudioTooShortError(ValueError):
    pass


class TensorPreprocesser:
Example #7
# https://chrisalbon.com/code/deep_learning/pytorch/basics/check_if_pytorch_is_using_gpu/

import torch
print("torch import successful")
import torchaudio
print("torchaudio import successful")
print("torchaudio backends:", torchaudio.list_audio_backends())
print("")