# Example #1
# 0
    format='%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',	
    datefmt='%Y-%m-%d:%H:%M:%S', level=logging.INFO) 

# make sure this is set to WHAM root directory
WHAM_ROOT = os.getenv("WHAM_ROOT")
# Use a quarter of the available cores for data-loading workers.
NUM_WORKERS = multiprocessing.cpu_count() // 4
# All recipe artifacts live under the user's ~/.nussl cache.
OUTPUT_DIR = os.path.expanduser('~/.nussl/recipes/ideal_ratio_mask/')
RESULTS_DIR = os.path.join(OUTPUT_DIR, 'results')

# Mask approach passed to IdealRatioMask plus its extra keyword arguments.
# The commented line shows the PSA (phase-sensitive) alternative.
# APPROACH, KWARGS = 'psa', {'range_min': -np.inf, 'range_max':np.inf}
APPROACH, KWARGS = 'msa', {}

# Start from a clean results directory on every run (ignore_errors
# makes the rmtree a no-op if the directory does not exist yet).
shutil.rmtree(os.path.join(RESULTS_DIR), ignore_errors=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# WHAM test split ('tt') at 8 kHz is what gets evaluated below.
test_dataset = datasets.WHAM(WHAM_ROOT, sample_rate=8000, split='tt')


def separate_and_evaluate(item_):
    """Separate one WHAM item with an ideal ratio mask and save its scores.

    Builds an IdealRatioMask separator (soft mask, module-level APPROACH
    and KWARGS), evaluates the estimates against the reference sources
    with BSSEvalScale (permutation-aware), and writes the resulting
    score dict to RESULTS_DIR as ``<mixture file name>.json``.
    """
    ideal_mask = separation.benchmark.IdealRatioMask(
        item_['mix'],
        item_['sources'],
        approach=APPROACH,
        mask_type='soft',
        **KWARGS,
    )
    source_estimates = ideal_mask()

    # Reference signals in the same order the dict yields them.
    references = list(item_['sources'].values())
    bss_eval = evaluation.BSSEvalScale(
        references, source_estimates, compute_permutation=True)
    scores = bss_eval.evaluate()

    json_path = os.path.join(RESULTS_DIR, f"{item_['mix'].file_name}.json")
    with open(json_path, 'w') as fh:
        json.dump(scores, fh)
# Example #2
# 0
            400)  # get 400 frame excerpts (3.2 seconds)
    ])
    return tfm


def cache_dataset(_dataset):
    """Populate *_dataset*'s on-disk cache, then mark it as cached.

    Iterates the dataset once through a DataLoader (module-level
    NUM_WORKERS / BATCH_SIZE) via ml.train.cache_dataset, and flips
    ``cache_populated`` so later reads come from the cache.
    """
    loader = torch.utils.data.DataLoader(
        _dataset,
        num_workers=NUM_WORKERS,
        batch_size=BATCH_SIZE,
    )
    ml.train.cache_dataset(loader)
    _dataset.cache_populated = True


# Training split ('tr') with its transform pipeline; transforms cache
# under CACHE_ROOT/tr.
tfm = construct_transforms(os.path.join(CACHE_ROOT, 'tr'))
dataset = datasets.WHAM(WHAM_ROOT,
                        split='tr',
                        transform=tfm,
                        cache_populated=CACHE_POPULATED)

# Validation split ('cv') with the same pipeline, cached under
# CACHE_ROOT/cv.
tfm = construct_transforms(os.path.join(CACHE_ROOT, 'cv'))
val_dataset = datasets.WHAM(WHAM_ROOT,
                            split='cv',
                            transform=tfm,
                            cache_populated=CACHE_POPULATED)

if not CACHE_POPULATED:
    # cache datasets for speed
    cache_dataset(dataset)
    cache_dataset(val_dataset)

# ----------------------------------------------------
# -------------------- TRAINING ----------------------