예제 #1
0
def main():
    """Train a UNetGAN speech denoiser on cached TIMIT + noise mixtures.

    Builds the cached mixture dataset, splits it into train/val loaders,
    and resumes PyTorch Lightning training from a saved checkpoint.
    """
    TIMIT_CACHE_DIR = '/import/vision-eddydata/dm005_tmp/mixed_wavs_asteroid'
    # SNRs (dB) at which clean speech is mixed with noise for training.
    # (Removed an unused `test_snrs` list that was never read.)
    train_snrs = [-25, -20, -15, -10, -5, 0, 5, 10, 15]

    timit_train_misc = TimitDataset.load_with_cache(
        '../../../datasets/TIMIT',
        '../../../datasets/noises-train',
        cache_dir=TIMIT_CACHE_DIR,
        snrs=train_snrs,
        root_seed=42,
        prefetch_mixtures=False,
        dset_name='train-misc',
        subset='train',
        track_duration=48000)

    train_set, val_set = train_val_split(timit_train_misc)

    BATCH_SIZE = 64
    NUM_WORKERS = 10

    # drop_last keeps every batch the same size — GAN training tends to be
    # sensitive to a smaller trailing batch.
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        drop_last=True,
    )

    unetgan_module = UNetGAN(lr_g=2e-4, lr_d=2e-4)

    # Resume from the epoch-99 checkpoint and continue up to epoch 102,
    # using all visible GPUs in data-parallel mode.
    trainer = Trainer(
        max_epochs=102,
        gpus=-1,
        accelerator='dp',
        resume_from_checkpoint=
        'logs/lightning_logs/version_0/checkpoints/epoch=99-step=395190.ckpt')
    trainer.fit(unetgan_module, train_loader, val_loader)
예제 #2
0
def prepare_dataloaders(args):
    """Build train/val DataLoaders of noisy TIMIT mixtures from config *args*."""
    noise_dir = PurePath(args.dset.noises) / 'noises-train-drones'
    noises = CachedWavSet(noise_dir, sample_rate=args.sample_rate, precache=True)

    # Clean speech, then a seeded train/val split for reproducibility.
    timit = TimitDataset(args.dset.timit, subset='train',
                         sample_rate=args.sample_rate, with_path=False)

    n_train = int(len(timit) * (1 - args.dset.val_fraction))
    n_val = len(timit) - n_train
    split_gen = torch.Generator().manual_seed(args.random_seed)
    clean_train, clean_val = random_split(timit, [n_train, n_val], generator=split_gen)

    # Both splits get mixed with the same noise bank, under their own mixture config.
    mix_kwargs = dict(random_seed=args.random_seed, crop_length=args.crop_length)
    mixed_train = prepare_mixture_set(clean_train, noises,
                                      dict(args.dset.mixture_train), **mix_kwargs)
    mixed_val = prepare_mixture_set(clean_val, noises,
                                    dict(args.dset.mixture_val), **mix_kwargs)

    train_loader = DataLoader(mixed_train, shuffle=True, batch_size=args.batch_size,
                              num_workers=args.dl_workers, drop_last=True)
    val_loader = DataLoader(mixed_val, batch_size=args.batch_size,
                            num_workers=args.dl_workers, drop_last=True)

    return train_loader, val_loader
예제 #3
0
# Directory of drone-noise recordings used to corrupt the clean speech.
DRONE_NOISE_DIR = '/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/noises-train-drones'
# fixed SNRs for validation set — despite the TRAIN_ prefix, this list is only
# passed to the FixedMixtureSet validation data below; training instead draws
# SNRs from the continuous range (-25, -5) via `snr_range`.
TRAIN_SNRS = [-25, -20, -15, -10, -5]


TIMIT_DIR = PurePath('/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/TIMIT')
# 8 kHz copy of TIMIT — this is the version actually loaded below.
TIMIT_DIR_8kHZ = PurePath('/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/TIMIT_8kHZ')

# Reproducibility - fix all random seeds
seed_everything(SEED)

# Load noises, resample and save into the memory
noises = CachedWavSet(DRONE_NOISE_DIR, sample_rate=SAMPLE_RATE, precache=True)

# Load clean data and split it into train and val
timit = TimitDataset(TIMIT_DIR_8kHZ, subset='train', sample_rate=SAMPLE_RATE, with_path=False)
timit_train, timit_val = train_val_split(timit, val_fraction=0.1, random_seed=SEED)

# Training data mixes crops randomly on the fly with random SNR in range (effectively infinite training data)
# `repeat_factor=30` means that the dataset contains 30 copies of itself - it is the easiest way to make the epoch longer
timit_train = RandomMixtureSet(timit_train, noises, random_seed=SEED, snr_range=(-25, -5),
                               crop_length=CROP_LEN, repeat_factor=30)

# Validation data is fixed (for stability): mix every clean clip with all the noises in the folder
# Argument `mixtures_per_clean` regulates with how many different noise files each clean file will be mixed
timit_val = FixedMixtureSet(timit_val, noises, snrs=TRAIN_SNRS, random_seed=SEED,
                            mixtures_per_clean=5, crop_length=CROP_LEN)

NUM_WORKERS = 5
# drop_last keeps every training batch the same size.
train_loader = DataLoader(timit_train, shuffle=True, batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS, drop_last=True)
예제 #4
0
def download_datasets(args):
    """Fetch the TIMIT corpus and the noise archive named in *args*."""
    logger.info('Loading datasets')
    dset_cfg = args.dset
    TimitDataset.download(dset_cfg.timit, sample_rate=args.sample_rate)
    download_and_unpack_noises(dset_cfg.noises)
예제 #5
0
from asteroid.masknn import UNetGANGenerator, UNetGANDiscriminator
from asteroid.utils.notebook_utils import show_wav

sys.path.append('../egs')

from timit_drones.evaluate import evaluate_model

# Paths and constants for evaluating on the TIMIT test split mixed with drone noises.
TIMIT_DIR_8kHZ ='/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/TIMIT_8kHZ' 
TEST_NOISE_DIR = '/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/noises-test-drones'
SAMPLE_RATE    = 8000
# SNRs (dB) at which every clean test clip is mixed with each noise.
TEST_SNRS      = [-30, -25, -20, -15, -10, -5, 0]
SEED           = 42


timit_test_clean = TimitDataset(TIMIT_DIR_8kHZ, subset='test', sample_rate=SAMPLE_RATE, with_path=False)
# Small subset (first len/20 items, ~5%) for quick sanity-check evaluations.
timit_small = Subset(timit_test_clean, np.arange(len(timit_test_clean)//20))
noises_test = CachedWavSet(TEST_NOISE_DIR, sample_rate=SAMPLE_RATE, precache=True)


# Seeded fixed mixtures so evaluation is reproducible across runs.
# NOTE(review): with_snr=True presumably makes each item also carry the SNR it
# was mixed at — confirm against FixedMixtureSet.
timit_test_small = FixedMixtureSet(timit_small, noises_test, snrs=TEST_SNRS, random_seed=SEED, with_snr=True)
timit_test = FixedMixtureSet(timit_test_clean, noises_test, snrs=TEST_SNRS, random_seed=SEED, with_snr=True)

# Use the file-system strategy for sharing tensors between DataLoader workers
# (avoids hitting the open-file-descriptor limit with many workers).
torch.multiprocessing.set_sharing_strategy('file_system')

# Display names for the evaluation metrics keyed by their internal identifiers.
metrics_names = {
    'pesq': 'PESQ',
    'stoi': 'STOI',
    'si_sdr': 'SI-SDR',
}