Example #1
def main(_run, batch_size, datasets, debug, experiment_dir, database_json):
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            iterable = prepare_iterable(
                db,
                dataset,
                batch_size,
                return_keys=None,
                prefetch=False,
            )

            for batch in dlp_mpi.split_managed(iterable,
                                               is_indexable=False,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(pt.data.example_to_device(batch))

                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]
                entry['metrics'] \
                    = pb_bss.evaluation.OutputMetrics(speech_prediction=z,
                                                      speech_source=s).as_dict()

                summary[dataset][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'len(summary_list): {len(summary_list)}')
        for partial_summary in summary_list:
            for dataset, values in partial_summary.items():
                summary[dataset].update(values)

        for dataset, values in summary.items():
            print(f'{dataset}: {len(values)}')

        result_json_path = experiment_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
Example #2
def speedup():
    print(f'speedup test {RANK}')
    examples = list(range(4))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    sleep_time = 0.1

    def bar(i):
        time.sleep(sleep_time)
        # print(f'Callback from {RANK}')
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)

    start = time.perf_counter()
    for i in dlp_mpi.map_unordered(bar, examples):
        # print(f'Loop body from {RANK}')
        assert dlp_mpi.RANK in [0], (dlp_mpi.RANK, dlp_mpi.SIZE)
    elapsed = time.perf_counter() - start

    serial_time = sleep_time * len(examples)

    # Two workers and one manager (3 MPI processes) should reduce the time
    # to 0.5 of the serial time; allow for some Python overhead.
    # (A launch sketch with three MPI processes follows this example.)
    assert elapsed < 0.6 * serial_time, (elapsed, serial_time)
    assert elapsed >= 0.5 * serial_time, (elapsed, serial_time)
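This test, like the others in this listing, asserts ranks == [0, 1, 2], i.e. it assumes exactly one manager process (rank 0) and two workers. The following minimal sketch shows such a setup; the file name and the work function are placeholders, and the launch command is only one typical way to start three MPI processes.

# speedup_sketch.py -- placeholder file name
# Typically launched with three MPI processes, e.g.:
#     mpiexec -np 3 python speedup_sketch.py
import time
import dlp_mpi


def work(example):
    # Executed on the workers (ranks 1 and 2).
    time.sleep(0.1)
    return example


if __name__ == '__main__':
    for result in dlp_mpi.map_unordered(work, list(range(4))):
        # Only the manager (rank 0) executes this loop body.
        print(f'rank {dlp_mpi.RANK} received {result}')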
Example #3
def worker_fails():
    print(f'worker_fails test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    processed = []
    try:
        dlp_mpi.barrier()
        if RANK == 2:
            # Delay rank 2; this used to ensure that rank 1 gets the first
            # example. It no longer works, because split_managed uses
            # COMM.Clone.
            time.sleep(0.1)
        for i in dlp_mpi.split_managed(examples, progress_bar=False):
            processed.append(i)
            if RANK == 1:
                print(f'let {RANK} fail for data {i}')
                raise ValueError('failed')
            assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
    except ValueError:
        assert RANK in [1], RANK
        assert processed in [[0], [1]], processed
    except AssertionError:
        assert RANK in [0], RANK
        assert processed == [], processed
    else:
        assert RANK in [2], RANK
        assert processed == [1, 2, 3, 4], processed
Example #4
def overhead():
    """

    This test shows the overhead of map_unordered.
    A simple for loop with map cam process around 1 000 000 examples per
    second. When using map_unordered, obviously the number should decrease.

    When your code processes less than 1000 examples per second, you can expect
    a gain from map_unordered. When you process serial more than 10000 examples
    per second, it is unlikely to get a gain from map_unordered.
    Thing about chunking to get less than 100 example per second.
    """
    print(f'executable test {RANK}')

    examples = list(range(10000))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    total = 0
    # def bar(i):
    #     nonlocal total
    #     total += i
    #     return i

    start = time.perf_counter()
    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        total += i
    elapsed = time.perf_counter() - start

    if RANK == 0:
        assert total == 0, total
    elif RANK == 1:
        assert total > 10_000_000, total
    elif RANK == 2:
        assert total > 10_000_000, total
    else:
        raise ValueError(RANK)

    time_per_example = elapsed / 1000
    mpi_examples_per_second = 1 / time_per_example

    assert mpi_examples_per_second >= 10_000, mpi_examples_per_second
    assert mpi_examples_per_second <= 300_000, mpi_examples_per_second

    print('split_managed examples/second =', mpi_examples_per_second)

    start = time.perf_counter()
    for i in examples:
        total += i
    elapsed = time.perf_counter() - start

    time_per_example = elapsed / 1000
    py_examples_per_second = 1 / time_per_example

    assert py_examples_per_second >= 250_000, py_examples_per_second
    assert py_examples_per_second <= 9_000_000, py_examples_per_second
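The docstring above recommends chunking when a single example is processed much faster than the per-example MPI communication. The sketch below illustrates that idea; work is a hypothetical per-example function and the chunk size is an arbitrary assumption that would need tuning.

import dlp_mpi


def work(example):
    # Hypothetical cheap per-example computation.
    return example * example


def work_on_chunk(chunk):
    # One MPI task now covers many examples, so the communication overhead
    # is paid once per chunk instead of once per example.
    return [work(example) for example in chunk]


examples = list(range(10000))
chunk_size = 100  # assumption: large enough that one chunk outweighs the overhead
chunks = [examples[i:i + chunk_size]
          for i in range(0, len(examples), chunk_size)]

results = []
for chunk_result in dlp_mpi.map_unordered(work_on_chunk, chunks):
    # Only the manager (rank 0) reaches this loop body.
    results.extend(chunk_result)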
Example #5
def pbar():
    print(f'executable test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    def bar(i):
        time.sleep(0.04)
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)

    class MockPbar:
        call_history = []

        def __init__(self):
            self.i = 0

        def set_description(self, text):
            self.call_history.append(text)

        def update(self, inc=1):
            self.i += 1
            self.call_history.append(f'update {self.i}')

    import contextlib

    @contextlib.contextmanager
    def mock_pbar(total, disable):
        assert disable is False, disable
        yield MockPbar()

    import mock

    with mock.patch('tqdm.tqdm', mock_pbar):

        dlp_mpi.barrier()
        if RANK == 2:
            time.sleep(0.02)

        for i in dlp_mpi.map_unordered(
                bar,
                examples,
                progress_bar=True,
        ):
            assert dlp_mpi.RANK in [0], (dlp_mpi.RANK, dlp_mpi.SIZE)

    if RANK == 0:
        assert MockPbar.call_history == [
            'busy: 2', 'update 1', 'update 2', 'update 3', 'update 4',
            'busy: 1', 'update 5', 'busy: 0'
        ], MockPbar.call_history
    else:
        assert MockPbar.call_history == [], MockPbar.call_history
Example #6
def executable():
    print(f'executable test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
Example #7
def cross_communication():
    print(f'cross_communication test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    dlp_mpi.barrier()
    if RANK == 1:
        time.sleep(0.1)
    elif RANK == 2:
        pass

    results = []
    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)

        if RANK == 1:
            results.append(i)
        elif RANK == 2:
            results.append(i)
            time.sleep(0.2)

    if RANK == 1:
        assert results in [[0, 2, 3, 4], [1, 2, 3, 4]], results
    elif RANK == 2:
        assert results in [[1], [0]], results

    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
        if RANK == 1:
            results.append(i)
            time.sleep(0.001)
        elif RANK == 2:
            results.append(i)
            time.sleep(0.2)

    if RANK == 1:
        assert results in [
            [0, 2, 3, 4, 0, 2, 3, 4],
            [0, 2, 3, 4, 1, 2, 3, 4],
            [1, 2, 3, 4, 0, 2, 3, 4],
            [1, 2, 3, 4, 1, 2, 3, 4],
        ], results
    elif RANK == 2:
        assert results in [
            [1, 1],
            [1, 0],
            [0, 1],
            [0, 0],
        ], results
Example #8
def main(_run, out):
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()

    data = []

    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):
        for prediction in [
                'source',
                'early_0',
                'early_1',
                'image_0',
                'image_1',
                'image_0_noise',
                'image_1_noise',
        ]:
            for source in [
                    'source',
                    'early_0',
                    'early_1',
                    'image_0',
                    'image_1',
                    'image_0_noise',
                    'image_1_noise',
            ]:
                scores = get_scores(ex, prediction=prediction, source=source)
                for score_name, score_value in scores.items():
                    data.append(
                        dict(
                            score_name=score_name,
                            prediction=prediction,
                            source=source,
                            example_id=ex['example_id'],
                            value=score_value,
                        ))

    data = dlp_mpi.gather(data)

    if dlp_mpi.IS_MASTER:
        data = [entry for worker_data in data for entry in worker_data]

        if out is not None:
            assert isinstance(out, str), out
            assert out.endswith('.json'), out
            print(f'Write details to {out}.')
            dump_json(data, out)

        summary(data)
Example #9
def executable():
    print(f'executable test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    def bar(i):
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)

    for i in dlp_mpi.map_unordered(bar, examples):
        assert dlp_mpi.RANK in [0], (dlp_mpi.RANK, dlp_mpi.SIZE)
Example #10
def overhead():
    """

    This test shows the overhead of map_unordered.
    A simple for loop with map cam process around 1 000 000 examples per
    second. When using map_unordered, obviously the number should decrease.

    When your code processes less than 1000 examples per second, you can expect
    a gain from map_unordered. When you process serial more than 10000 examples
    per second, it is unlikely to get a gain from map_unordered.
    Thing about chunking to get less than 100 example per second.
    """
    print(f'executable test {RANK}')

    examples = list(range(10000))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    total = 0

    def bar(i):
        nonlocal total
        total += i
        return i

    start = time.perf_counter()
    for i in dlp_mpi.map_unordered(bar, examples):
        total += i
    elapsed = time.perf_counter() - start

    time_per_example = elapsed / 1000
    mpi_examples_per_second = 1 / time_per_example

    assert mpi_examples_per_second >= 10_000, mpi_examples_per_second
    assert mpi_examples_per_second <= 300_000, mpi_examples_per_second

    start = time.perf_counter()
    for i in map(bar, examples):
        total += i
    elapsed = time.perf_counter() - start

    time_per_example = elapsed / 1000
    py_examples_per_second = 1 / time_per_example

    assert py_examples_per_second >= 250_000, py_examples_per_second
    assert py_examples_per_second <= 9_000_000, py_examples_per_second
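The two overhead tests (Example #4 above and this one) differ in where the work runs: with split_managed the loop body itself executes on the workers and partial results are usually combined with dlp_mpi.gather, whereas with map_unordered the callback executes on the workers and the loop body on the manager (rank 0). A minimal sketch of both patterns, with process_one as a placeholder:

import dlp_mpi


def process_one(example):
    # Placeholder for the actual per-example work.
    return example * example


examples = list(range(100))

# Pattern 1: split_managed -- the loop body runs on the workers.
partial = []
for example in dlp_mpi.split_managed(examples, progress_bar=False):
    partial.append(process_one(example))
gathered = dlp_mpi.gather(partial, root=dlp_mpi.MASTER)
if dlp_mpi.IS_MASTER:
    combined = [value for worker_results in gathered for value in worker_results]

# Pattern 2: map_unordered -- the callback runs on the workers,
# the loop body runs on the manager (rank 0).
streamed = []
for value in dlp_mpi.map_unordered(process_one, examples):
    streamed.append(value)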
Example #11
def evaluate_model(dataset, model, get_sad_fn,
                   get_target_fn=lambda x: x['activation'],
                   num_thresholds=201, buffer_zone=0.5,
                   is_indexable=True, allow_single_worker=True,
                   sample_rate=8000):

    tp_fp_tn_fn = np.zeros((num_thresholds, 4), dtype=int)

    import dlp_mpi
    for example in dlp_mpi.split_managed(
        dataset, is_indexable=is_indexable,
        allow_single_worker=allow_single_worker,
    ):
        target = get_target_fn(example)
        adjusted_target = adjust_annotation_fn(
            target, buffer_zone=buffer_zone,
            sample_rate=sample_rate
        )
        model_out = model(example)
        for idx, th in enumerate(np.linspace(0, 1, num_thresholds)):
            th = np.round(th, 2)
            sad = get_sad_fn(model_out, th, example)
            out = get_tp_fp_tn_fn(
                adjusted_target, sad,
                sample_rate=sample_rate, adjust_annotation=False
            )
            tp_fp_tn_fn[idx] = [tp_fp_tn_fn[idx][idy] + o for idy, o in
                                enumerate(out)]

    dlp_mpi.barrier()
    tp_fp_tn_fn_gather = dlp_mpi.gather(tp_fp_tn_fn, root=dlp_mpi.MASTER)
    if dlp_mpi.IS_MASTER:
        tp_fp_tn_fn = np.zeros((num_thresholds, 4), dtype=int)
        for array in tp_fp_tn_fn_gather:
            tp_fp_tn_fn += array
    else:
        tp_fp_tn_fn = None
    return tp_fp_tn_fn
Example #12
def worker_fails():
    print(f'executable test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    def bar(i):
        if RANK == 1:
            print(f'let {RANK} fail for data {i}')
            raise ValueError('failed')
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
        return i, RANK

    processed = []
    try:
        dlp_mpi.barrier()
        if RANK == 2:
            # Delay rank 2, this ensures that rank 1 gets the first example
            time.sleep(0.1)
        for i, worker_rank in dlp_mpi.map_unordered(bar, examples):
            print(
                f'Loop body from {RANK} for data {i} that was processed by {worker_rank}'
            )
            assert dlp_mpi.RANK in [0], (dlp_mpi.RANK, dlp_mpi.SIZE)
            processed.append(i)
    except ValueError:
        assert RANK in [1], RANK
    except AssertionError:
        assert RANK in [0], RANK
        # Example zero fails on worker 1, but the master process only fails
        # at the end of the for loop, so examples 1 to 4 are processed.
        assert processed == [1, 2, 3, 4], processed
    else:
        assert RANK in [2], RANK
Example #13
def bottleneck():
    print(f'bottleneck test {RANK}')
    examples = list(range(4))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    sleep_time = 0.1

    def bar(i):
        # print(f'Callback from {RANK} for data {i}')
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
        return i, RANK

    start = time.perf_counter()
    for i, worker_rank in dlp_mpi.map_unordered(bar, examples):
        time.sleep(sleep_time)
        # print(f'Loop body from {RANK} for data {i} that was processed by {worker_rank}')
        assert dlp_mpi.RANK in [0], (dlp_mpi.RANK, dlp_mpi.SIZE)
    elapsed = time.perf_counter() - start

    # Two workers and one manager (3 MPI processes) would normally reduce the
    # time to 0.5 of the serial time, but here the load is in the loop body
    # -> no speedup. Allow for some Python overhead.
    # (A sketch of the intended division of labour follows this example.)
    if dlp_mpi.IS_MASTER:
        serial_time = sleep_time * len(examples)
        assert elapsed < 1.1 * serial_time, (elapsed, serial_time)
        assert elapsed >= 0.9 * serial_time, (elapsed, serial_time)
    else:
        # The workers finish while the master still processes the last two
        # examples. One worker finishes two steps earlier, the other one
        # step earlier.
        serial_time_high = sleep_time * (len(examples) - 1)
        serial_time_low = sleep_time * (len(examples) - 2)
        assert elapsed < 1.1 * serial_time_high, (elapsed, serial_time_high)
        assert elapsed >= 0.9 * serial_time_low, (elapsed, serial_time_low)
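The comments above point out that map_unordered only helps when the expensive part sits inside the callback; whatever happens in the loop body runs serially on the manager. A minimal sketch of the intended division of labour, with expensive as a placeholder for the real work:

import time
import dlp_mpi

examples = list(range(4))
collected = []


def expensive(example):
    # Put the expensive part into the callback: it runs on the workers
    # in parallel.
    time.sleep(0.1)
    return example * example


for result in dlp_mpi.map_unordered(expensive, examples):
    # Keep the loop body cheap: it runs serially on the manager (rank 0).
    collected.append(result)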
Example #14
def main(_run, datasets, debug, experiment_dir, export_audio, sample_rate,
         _log, database_json, oracle_num_spk, max_iterations):
    experiment_dir = Path(experiment_dir)

    if mpi.IS_MASTER:
        sacred.commands.print_config(_run)
        dump_config_and_makefile()

    model = get_model()
    db = JsonDatabase(database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            iterable = prepare_iterable(
                db,
                dataset,
                1,
                chunk_size=-1,
                prefetch=False,
                shuffle=False,
                iterator_slice=slice(mpi.RANK, 20 if debug else None,
                                     mpi.SIZE),
            )

            if export_audio:
                (experiment_dir / 'audio' / dataset).mkdir(parents=True,
                                                           exist_ok=True)

            for batch in tqdm(
                    iterable,
                    total=len(iterable),
                    disable=not mpi.IS_MASTER,
                    desc=dataset,
            ):
                example_id = batch['example_id'][0]
                summary[dataset][example_id] = entry = dict()
                oracle_speaker_count = \
                    entry['oracle_speaker_count'] = batch['s'][0].shape[0]

                try:
                    model_output = model.decode(
                        pt.data.example_to_device(batch),
                        max_iterations=max_iterations,
                        oracle_num_speakers=oracle_speaker_count
                        if oracle_num_spk else None)

                    # Bring to numpy float64 for evaluation metrics computation
                    s = batch['s'][0].astype(np.float64)
                    z = model_output['out'][0].cpu().numpy().astype(np.float64)

                    estimated_speaker_count = \
                        entry['estimated_speaker_count'] = z.shape[0]
                    entry['source_counting_accuracy'] = \
                        estimated_speaker_count == oracle_speaker_count

                    if oracle_speaker_count == estimated_speaker_count:
                        # These evaluations don't work if the number of
                        # speakers in s and z don't match
                        entry['mir_eval'] = pb_bss.evaluation.mir_eval_sources(
                            s, z, return_dict=True)

                        # Get the correct order for si_sdr and saving
                        z = z[entry['mir_eval']['selection']]

                        entry['si_sdr'] = pb_bss.evaluation.si_sdr(s, z)
                    else:
                        warnings.warn(
                            'The number of speakers is estimated incorrectly '
                            'for some examples! The calculated SDR values '
                            'might not be representative!')

                    if export_audio:
                        entry['audio_path'] = batch['audio_path']
                        entry['audio_path'].setdefault('estimated', [])

                        for k, audio in enumerate(z):
                            audio_path = (experiment_dir / 'audio' / dataset /
                                          f'{example_id}_{k}.wav')
                            pb.io.dump_audio(audio,
                                             audio_path,
                                             sample_rate=sample_rate)
                            entry['audio_path']['estimated'].append(audio_path)
                except:
                    _log.error(f'Exception was raised in example with ID '
                               f'"{example_id}"')
                    raise

    summary_list = mpi.gather(summary, root=mpi.MASTER)

    if mpi.IS_MASTER:
        # Combine all summaries to one
        for partial_summary in summary_list:
            for dataset, values in partial_summary.items():
                summary[dataset].update(values)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')

        # Write summary to JSON
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute means for some metrics
        mean_keys = [
            'mir_eval.sdr', 'mir_eval.sar', 'mir_eval.sir', 'si_sdr',
            'source_counting_accuracy'
        ]
        means = {}
        for dataset, dataset_results in summary.items():
            means[dataset] = {}
            flattened = {
                k: pb.utils.nested.flatten(v)
                for k, v in dataset_results.items()
            }
            for mean_key in mean_keys:
                try:
                    means[dataset][mean_key] = np.mean(
                        np.array([v[mean_key] for v in flattened.values()]))
                except KeyError:
                    warnings.warn(f'Couldn\'t compute mean for {mean_key}.')
            means[dataset] = pb.utils.nested.deflatten(means[dataset])

        mean_json_path = experiment_dir / 'means.json'
        _log.info(f'Exporting means: {mean_json_path}')
        pb.io.dump_json(means, mean_json_path)

        _log.info('Resulting means:')

        pprint(means)
Example #15
def main(_run, datasets, debug, experiment_dir, dump_audio, sample_rate, _log,
         database_json):
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)
        dump_config_and_makefile()

    model = get_model()
    db = JsonDatabase(database_json)

    model.eval()
    results = defaultdict(dict)
    with torch.no_grad():
        for dataset in datasets:
            iterable = prepare_dataset(
                db,
                dataset,
                1,
                chunk_size=-1,
                prefetch=False,
                shuffle=False,
                dataset_slice=slice(dlp_mpi.RANK, 20 if debug else None,
                                    dlp_mpi.SIZE),
            )

            if dump_audio:
                (experiment_dir / 'audio' / dataset).mkdir(parents=True,
                                                           exist_ok=True)

            for batch in tqdm(iterable,
                              total=len(iterable),
                              disable=not dlp_mpi.IS_MASTER):
                example_id = batch['example_id'][0]
                results[dataset][example_id] = entry = dict()

                try:
                    model_output = model(model.example_to_device(batch))

                    # Bring to numpy float64 for evaluation metrics computation
                    observation = batch['y'][0].astype(np.float64)[None, ]
                    speech_prediction = (
                        model_output['out'][0].cpu().numpy().astype(
                            np.float64))
                    speech_source = batch['s'][0].astype(np.float64)

                    input_metrics = pb_bss.evaluation.InputMetrics(
                        observation=observation,
                        speech_source=speech_source,
                        sample_rate=sample_rate,
                        enable_si_sdr=True,
                    )

                    output_metrics = pb_bss.evaluation.OutputMetrics(
                        speech_prediction=speech_prediction,
                        speech_source=speech_source,
                        sample_rate=sample_rate,
                        enable_si_sdr=True,
                    )

                    # Select the metrics to compute
                    entry['input'] = dict(
                        mir_eval=input_metrics.mir_eval,
                        si_sdr=input_metrics.si_sdr,
                        # TODO: stoi fails with short speech segments (https://github.com/mpariente/pystoi/issues/21)
                        # stoi=input_metrics.stoi,
                        # TODO: pesq creates "Processing error" messages
                        # pesq=input_metrics.pesq,
                    )

                    # Remove selection from mir_eval dict to enable
                    # recursive calculation of improvement
                    entry['output'] = dict(
                        mir_eval={
                            k: v
                            for k, v in output_metrics.mir_eval.items()
                            if k != 'selection'
                        },
                        si_sdr=output_metrics.si_sdr,
                        # stoi=output_metrics.stoi,
                        # pesq=output_metrics.pesq,
                    )

                    entry['improvement'] = pb.utils.nested.nested_op(
                        operator.sub,
                        entry['output'],
                        entry['input'],
                    )
                    entry['selection'] = output_metrics.mir_eval['selection']

                    if dump_audio:
                        entry['audio_path'] = batch['audio_path']
                        entry['audio_path']['estimated'] = []
                        for k, audio in enumerate(speech_prediction):
                            audio_path = (experiment_dir / 'audio' / dataset /
                                          f'{example_id}_{k}.wav')
                            pb.io.dump_audio(audio,
                                             audio_path,
                                             sample_rate=sample_rate)
                            entry['audio_path']['estimated'].append(audio_path)
                except:
                    _log.error(f'Exception was raised in example with ID '
                               f'"{example_id}"')
                    raise

    results = dlp_mpi.gather(results, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        # Combine all results to one. This function raises an exception if it
        # finds duplicate keys
        results = pb.utils.nested.nested_merge(*results)

        for dataset, values in results.items():
            _log.info(f'{dataset}: {len(values)}')

        # Write results to JSON
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(results, result_json_path)

        # Compute means for some metrics
        means = compute_means(results)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f'Exporting means: {mean_json_path}')
        pb.io.dump_json(means, mean_json_path)

        _log.info('Resulting means:')

        pprint(means)
Example #16
def main(_run, batch_size, datasets, debug, experiment_dir, database_json,
         _log):
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset_name in datasets:
            dataset = prepare_dataset(db,
                                      dataset_name,
                                      batch_size,
                                      return_keys=None,
                                      prefetch=False,
                                      shuffle=False)

            for batch in dlp_mpi.split_managed(dataset,
                                               is_indexable=True,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(model.example_to_device(batch))

                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]

                input_metrics = pb_bss.evaluation.InputMetrics(
                    observation=batch['y'][0][None, :],
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )

                output_metrics = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=z,
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )
                entry['input'] = dict(mir_eval=input_metrics.mir_eval, )
                entry['output'] = dict(mir_eval={
                    k: v
                    for k, v in output_metrics.mir_eval.items()
                    if k != 'selection'
                }, )

                entry['improvement'] = pb.utils.nested.nested_op(
                    operator.sub,
                    entry['output'],
                    entry['input'],
                )
                entry['selection'] = output_metrics.mir_eval['selection']

                summary[dataset_name][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        _log.info(f'len(summary_list): {len(summary_list)}')
        summary = pb.utils.nested.nested_merge(*summary_list)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')
            assert len(values) == len(
                db.get_dataset(dataset)
            ), 'Number of results needs to match length of dataset!'
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute and save mean of metrics
        means = compute_means(summary)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f"Saving means to: {mean_json_path}")
        pb.io.dump_json(means, mean_json_path)
Example #17
def main(
    _run,
    out,
    mask_estimator,
    Observation,
    beamformer,
    postfilter,
    normalize_audio=True,
):
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()

    data = []

    out = Path(out)

    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):

        if mask_estimator is None:
            mask = None
        elif mask_estimator == 'cacgmm':
            mask = get_mask_from_cacgmm(ex)
        else:
            mask = get_mask_from_oracle(ex, mask_estimator)

        metric, score = get_scores(
            ex,
            mask,
            Observation=Observation,
            beamformer=beamformer,
            postfilter=postfilter,
        )

        est0, est1 = metric.speech_prediction_selection
        dump_audio(est0,
                   out / ex['dataset'] / f"{ex['example_id']}_0.wav",
                   normalize=normalize_audio)
        dump_audio(est1,
                   out / ex['dataset'] / f"{ex['example_id']}_1.wav",
                   normalize=normalize_audio)

        data.append(
            dict(
                example_id=ex['example_id'],
                value=score,
                dataset=ex['dataset'],
            ))

        # print(score, repr(score))

    data = dlp_mpi.gather(data)

    if dlp_mpi.IS_MASTER:
        data = [entry for worker_data in data for entry in worker_data]

        data = {  # itertools.groupby expects an ordered input
            dataset: list(subset)
            for dataset, subset in from_list(data).groupby(
                lambda ex: ex['dataset']).items()
        }

        for dataset, sub_data in data.items():
            print(f'Write details to {out}.')
            dump_json(sub_data, out / f'{dataset}_scores.json')

        for dataset, sub_data in data.items():
            summary = {}
            for k in sub_data[0]['value'].keys():
                m = np.mean([d['value'][k] for d in sub_data])
                print(dataset, k, m)
                summary[k] = m
            dump_json(summary, out / f'{dataset}_summary.json')
Example #18
def main(_run, datasets, debug, experiment_dir, export_audio,
         sample_rate, _log, database_json):
    experiment_dir = Path(experiment_dir)

    if mpi.IS_MASTER:
        sacred.commands.print_config(_run)
        dump_config_and_makefile()

    model = get_model()
    db = JsonDatabase(database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            iterable = prepare_iterable(
                db, dataset, 1,
                chunk_size=-1,
                prefetch=False,
                iterator_slice=slice(mpi.RANK, 20 if debug else None, mpi.SIZE),
            )

            if export_audio:
                (experiment_dir / 'audio' / dataset).mkdir(parents=True, exist_ok=True)

            for batch in tqdm(iterable, total=len(iterable), disable=not mpi.IS_MASTER):
                example_id = batch['example_id'][0]
                summary[dataset][example_id] = entry = dict()

                try:
                    model_output = model(pt.data.example_to_device(batch))

                    # Bring to numpy float64 for evaluation metrics computation
                    s = batch['s'][0].astype(np.float64)
                    z = model_output['out'][0].cpu().numpy().astype(np.float64)

                    entry['mir_eval'] \
                        = pb_bss.evaluation.mir_eval_sources(s, z, return_dict=True)

                    # Get the correct order for si_sdr and saving
                    z = z[entry['mir_eval']['selection']]

                    entry['si_sdr'] = pb_bss.evaluation.si_sdr(s, z)
                    # entry['stoi'] = pb_bss.evaluation.stoi(s, z, sample_rate)
                    # entry['pesq'] = pb_bss.evaluation.pesq(s, z, sample_rate)

                    if export_audio:
                        entry['audio_path'] = batch['audio_path']
                        for k, audio in enumerate(z):
                            audio_path = experiment_dir / 'audio' / dataset / f'{example_id}_{k}.wav'
                            pb.io.dump_audio(audio, audio_path, sample_rate=sample_rate)
                            entry['audio_path'].setdefault('estimated', []).append(audio_path)
                except:
                    _log.error(f'Exception was raised in example with ID "{example_id}"')
                    raise

    summary_list = mpi.gather(summary, root=mpi.MASTER)

    if mpi.IS_MASTER:
        # Combine all summaries to one
        for partial_summary in summary_list:
            for dataset, values in partial_summary.items():
                summary[dataset].update(values)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')

        # Write summary to JSON
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute means for some metrics
        mean_keys = ['mir_eval.sdr', 'mir_eval.sar', 'mir_eval.sir', 'si_sdr']
        means = {}
        for dataset, dataset_results in summary.items():
            means[dataset] = {}
            flattened = {
                k: pb.utils.nested.flatten(v) for k, v in
                dataset_results.items()
            }
            for mean_key in mean_keys:
                means[dataset][mean_key] = np.mean(np.array([
                    v[mean_key] for v in flattened.values()
                ]))
            means[dataset] = pb.utils.nested.deflatten(means[dataset])

        mean_json_path = experiment_dir / 'means.json'
        _log.info(f'Exporting means: {mean_json_path}')
        pb.io.dump_json(means, mean_json_path)

        _log.info('Resulting means:')

        pprint(means)