format='%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d:%H:%M:%S', level=logging.INFO) # make sure this is set to WHAM root directory WHAM_ROOT = os.getenv("WHAM_ROOT") NUM_WORKERS = multiprocessing.cpu_count() // 4 OUTPUT_DIR = os.path.expanduser('~/.nussl/recipes/ideal_ratio_mask/') RESULTS_DIR = os.path.join(OUTPUT_DIR, 'results') # APPROACH, KWARGS = 'psa', {'range_min': -np.inf, 'range_max':np.inf} APPROACH, KWARGS = 'msa', {} shutil.rmtree(os.path.join(RESULTS_DIR), ignore_errors=True) os.makedirs(RESULTS_DIR, exist_ok=True) test_dataset = datasets.WHAM(WHAM_ROOT, sample_rate=8000, split='tt') def separate_and_evaluate(item_): separator = separation.benchmark.IdealRatioMask( item_['mix'], item_['sources'], approach=APPROACH, mask_type='soft', **KWARGS) estimates = separator() evaluator = evaluation.BSSEvalScale( list(item_['sources'].values()), estimates, compute_permutation=True) scores = evaluator.evaluate() output_path = os.path.join(RESULTS_DIR, f"{item_['mix'].file_name}.json") with open(output_path, 'w') as f: json.dump(scores, f)
400) # get 400 frame excerpts (3.2 seconds) ]) return tfm def cache_dataset(_dataset): cache_dataloader = torch.utils.data.DataLoader(_dataset, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE) ml.train.cache_dataset(cache_dataloader) _dataset.cache_populated = True tfm = construct_transforms(os.path.join(CACHE_ROOT, 'tr')) dataset = datasets.WHAM(WHAM_ROOT, split='tr', transform=tfm, cache_populated=CACHE_POPULATED) tfm = construct_transforms(os.path.join(CACHE_ROOT, 'cv')) val_dataset = datasets.WHAM(WHAM_ROOT, split='cv', transform=tfm, cache_populated=CACHE_POPULATED) if not CACHE_POPULATED: # cache datasets for speed cache_dataset(dataset) cache_dataset(val_dataset) # ---------------------------------------------------- # -------------------- TRAINING ----------------------