def main():
    TIMIT_CACHE_DIR = '/import/vision-eddydata/dm005_tmp/mixed_wavs_asteroid'
    train_snrs = [-25, -20, -15, -10, -5, 0, 5, 10, 15]
    test_snrs = [-30, -25, -20, -15, -10, -5, 0]

    # Load the TIMIT training subset mixed with training noises (cached on disk)
    timit_train_misc = TimitDataset.load_with_cache(
        '../../../datasets/TIMIT', '../../../datasets/noises-train',
        cache_dir=TIMIT_CACHE_DIR, snrs=train_snrs, root_seed=42,
        prefetch_mixtures=False, dset_name='train-misc', subset='train',
        track_duration=48000)

    train_set, val_set = train_val_split(timit_train_misc)

    BATCH_SIZE = 64
    NUM_WORKERS = 10
    train_loader = DataLoader(train_set, shuffle=True, batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS, drop_last=True)
    val_loader = DataLoader(val_set, shuffle=False, batch_size=BATCH_SIZE,
                            num_workers=NUM_WORKERS, drop_last=True)

    # Train the UNetGAN model, resuming from the last saved checkpoint
    unetgan_module = UNetGAN(lr_g=2e-4, lr_d=2e-4)
    trainer = Trainer(
        max_epochs=102, gpus=-1, accelerator='dp',
        resume_from_checkpoint='logs/lightning_logs/version_0/checkpoints/epoch=99-step=395190.ckpt')
    trainer.fit(unetgan_module, train_loader, val_loader)
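# Hedged sketch (assumption, not from the original code): `train_val_split` used above is
# taken to be a thin wrapper around `torch.utils.data.random_split`, mirroring the split
# logic in `prepare_dataloaders` below. The default `val_fraction` and `random_seed`
# values are guesses.
import torch
from torch.utils.data import random_split

def train_val_split(dataset, val_fraction=0.1, random_seed=42):
    """Split a dataset into train/val parts with a fixed seed (sketch)."""
    len_train = int(len(dataset) * (1 - val_fraction))
    len_val = len(dataset) - len_train
    return random_split(dataset, [len_train, len_val],
                        generator=torch.Generator().manual_seed(random_seed))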
def prepare_dataloaders(args):
    train_noise_dir = PurePath(args.dset.noises) / 'noises-train-drones'
    noises = CachedWavSet(train_noise_dir, sample_rate=args.sample_rate, precache=True)

    # Load clean data and split it into train and val
    timit = TimitDataset(args.dset.timit, subset='train',
                         sample_rate=args.sample_rate, with_path=False)
    len_train = int(len(timit) * (1 - args.dset.val_fraction))
    len_val = len(timit) - len_train
    timit_train, timit_val = random_split(
        timit, [len_train, len_val],
        generator=torch.Generator().manual_seed(args.random_seed))

    timit_train = prepare_mixture_set(timit_train, noises, dict(args.dset.mixture_train),
                                      random_seed=args.random_seed, crop_length=args.crop_length)
    timit_val = prepare_mixture_set(timit_val, noises, dict(args.dset.mixture_val),
                                    random_seed=args.random_seed, crop_length=args.crop_length)

    train_loader = DataLoader(timit_train, shuffle=True, batch_size=args.batch_size,
                              num_workers=args.dl_workers, drop_last=True)
    val_loader = DataLoader(timit_val, batch_size=args.batch_size,
                            num_workers=args.dl_workers, drop_last=True)

    return train_loader, val_loader
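# Hedged usage sketch (assumption): `args` behaves like a nested config with attribute
# access, e.g. an OmegaConf/Hydra config. All paths, key names, and values below are
# placeholders, not taken from the project's actual config files.
from omegaconf import OmegaConf

example_args = OmegaConf.create({
    'sample_rate': 8000,
    'random_seed': 42,
    'crop_length': 16384,       # placeholder crop length in samples
    'batch_size': 64,
    'dl_workers': 5,
    'dset': {
        'timit': '/path/to/TIMIT',
        'noises': '/path/to/noises',
        'val_fraction': 0.1,
        'mixture_train': {'snr_range': [-25, -5]},          # keys assumed by prepare_mixture_set
        'mixture_val': {'snrs': [-25, -20, -15, -10, -5]},  # keys assumed by prepare_mixture_set
    },
})
train_loader, val_loader = prepare_dataloaders(example_args)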
DRONE_NOISE_DIR = '/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/noises-train-drones'
# Fixed SNRs for the validation set
TRAIN_SNRS = [-25, -20, -15, -10, -5]

TIMIT_DIR = PurePath('/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/TIMIT')
TIMIT_DIR_8kHZ = PurePath('/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/TIMIT_8kHZ')

# Reproducibility - fix all random seeds
seed_everything(SEED)

# Load noises, resample and cache them in memory
noises = CachedWavSet(DRONE_NOISE_DIR, sample_rate=SAMPLE_RATE, precache=True)

# Load clean data and split it into train and val
timit = TimitDataset(TIMIT_DIR_8kHZ, subset='train', sample_rate=SAMPLE_RATE, with_path=False)
timit_train, timit_val = train_val_split(timit, val_fraction=0.1, random_seed=SEED)

# Training data mixes crops on the fly with a random SNR in the given range
# (effectively infinite training data). `repeat_factor=30` means the dataset
# contains 30 copies of itself - the easiest way to make an epoch longer.
timit_train = RandomMixtureSet(timit_train, noises, random_seed=SEED,
                               snr_range=(-25, -5), crop_length=CROP_LEN,
                               repeat_factor=30)

# Validation data is fixed (for stability): mix every clean clip with all the noises in the folder.
# `mixtures_per_clean` controls how many different noise files each clean file is mixed with.
timit_val = FixedMixtureSet(timit_val, noises, snrs=TRAIN_SNRS, random_seed=SEED,
                            mixtures_per_clean=5, crop_length=CROP_LEN)

NUM_WORKERS = 5
train_loader = DataLoader(timit_train, shuffle=True, batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS, drop_last=True)
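# Sketch completing the excerpt above (assumption): the validation loader mirrors the
# training one but without shuffling, following the pattern in `prepare_dataloaders`.
val_loader = DataLoader(timit_val, batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS, drop_last=True)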
def download_datasets(args):
    logger.info('Loading datasets')
    TimitDataset.download(args.dset.timit, sample_rate=args.sample_rate)
    download_and_unpack_noises(args.dset.noises)
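# Hedged sketch (assumption): `download_and_unpack_noises` is not shown in this excerpt.
# A minimal stdlib-only version could fetch an archive and extract it into the target
# directory. The archive URL is a hypothetical parameter, not taken from the project.
import tarfile
import urllib.request
from pathlib import Path

def download_and_unpack_noises(target_dir, archive_url='https://example.com/noises.tar.gz'):
    """Download the noise archive (if missing) and extract it into `target_dir` (sketch)."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    archive_path = target / 'noises.tar.gz'
    if not archive_path.exists():
        urllib.request.urlretrieve(archive_url, archive_path)
    with tarfile.open(archive_path) as tar:
        tar.extractall(target)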
import sys

import numpy as np
import torch
from torch.utils.data import Subset

from asteroid.masknn import UNetGANGenerator, UNetGANDiscriminator
from asteroid.utils.notebook_utils import show_wav

sys.path.append('../egs')
from timit_drones.evaluate import evaluate_model

TIMIT_DIR_8kHZ = '/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/TIMIT_8kHZ'
TEST_NOISE_DIR = '/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets/noises-test-drones'

SAMPLE_RATE = 8000
TEST_SNRS = [-30, -25, -20, -15, -10, -5, 0]
SEED = 42

# Clean TIMIT test set, plus a small subset (1/20th of the clips) for quick checks
timit_test_clean = TimitDataset(TIMIT_DIR_8kHZ, subset='test',
                                sample_rate=SAMPLE_RATE, with_path=False)
timit_small = Subset(timit_test_clean, np.arange(len(timit_test_clean) // 20))

# Test noises and fixed (deterministic) noisy mixtures at the chosen SNRs
noises_test = CachedWavSet(TEST_NOISE_DIR, sample_rate=SAMPLE_RATE, precache=True)
timit_test_small = FixedMixtureSet(timit_small, noises_test, snrs=TEST_SNRS,
                                   random_seed=SEED, with_snr=True)
timit_test = FixedMixtureSet(timit_test_clean, noises_test, snrs=TEST_SNRS,
                             random_seed=SEED, with_snr=True)

torch.multiprocessing.set_sharing_strategy('file_system')

metrics_names = {
    'pesq': 'PESQ',
    'stoi': 'STOI',
    'si_sdr': 'SI-SDR',
}
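# Hedged sketch (assumption): one way per-SNR metrics could be aggregated over
# `timit_test_small` with `asteroid.metrics.get_metrics`. The `(mix, clean, snr)` item
# layout of `FixedMixtureSet(..., with_snr=True)` and the `model(mix)` call are
# assumptions, not confirmed by this excerpt; `evaluate_model` from
# `timit_drones.evaluate` is the project's actual evaluation entry point.
from collections import defaultdict
from asteroid.metrics import get_metrics

def quick_eval(model, dataset, sample_rate=SAMPLE_RATE):
    """Collect PESQ/STOI/SI-SDR per SNR over a fixed mixture dataset (sketch)."""
    per_snr = defaultdict(list)
    model.eval()
    with torch.no_grad():
        for mix, clean, snr in dataset:
            estimate = model(mix.unsqueeze(0)).squeeze(0)
            scores = get_metrics(mix.numpy(), clean.numpy(), estimate.numpy(),
                                 sample_rate=sample_rate,
                                 metrics_list=list(metrics_names.keys()))
            per_snr[snr].append({metrics_names[k]: scores[k] for k in metrics_names})
    return per_snr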