Example #1
def validate(_config):
    import numpy as np
    import os
    from padertorch.contrib.jensheit.evaluation import evaluate_masks
    from functools import partial
    from paderbox.io import dump_json
    from concurrent.futures import ThreadPoolExecutor
    assert len(ex.current_run.observers) == 1, (
        '`FileObserver` missing. Add a `FileObserver` with `-F foo/bar/`.')
    storage_dir = Path(ex.current_run.observers[0].basedir)
    assert not (storage_dir / 'results.json').exists(), (
        f'model_dir has already been evaluated, {storage_dir}')
    trainer, provider = initialize_trainer_provider(task='validate')
    trainer.model.cpu()
    eval_iterator = provider.get_eval_iterator()
    evaluation_json = dict(snr=dict(), pesq=dict())
    provider.opts.multichannel = True
    batch_size = 1
    provider.opts.batch_size = batch_size
    with ThreadPoolExecutor(os.cpu_count()) as executor:
        for example_id, snr, pesq in executor.map(
                partial(evaluate_masks,
                        model=trainer.model,
                        transform=provider.transformer.stft), eval_iterator):
            evaluation_json['snr'][example_id] = snr
            evaluation_json['pesq'][example_id] = pesq
    evaluation_json['pesq_mean'] = np.mean(
        list(evaluation_json['pesq'].values()))
    # Store the mean under a separate key instead of overwriting the
    # per-example SNR values.
    evaluation_json['snr_mean'] = np.mean(
        list(evaluation_json['snr'].values()))
    dump_json(evaluation_json, storage_dir / 'results.json')
Example #2
def initialize_trainer_provider(task, trainer_opts, provider_opts, _run):

    storage_dir = Path(trainer_opts['storage_dir'])
    if (storage_dir / 'init.json').exists():
        assert task in ['restart', 'validate'], task
    elif task in ['train', 'create_checkpoint']:
        dump_json(
            dict(trainer_opts=recursive_class_to_str(trainer_opts),
                 provider_opts=recursive_class_to_str(provider_opts)),
            storage_dir / 'init.json')
    else:
        raise ValueError(task, storage_dir)
    sacred.commands.print_config(_run)

    trainer = Trainer.from_config(trainer_opts)
    assert isinstance(trainer, Trainer)
    provider = config_to_instance(provider_opts)
    return trainer, provider
Example #3
def create_json(database_path, json_path):
    database_path = Path(database_path)
    rir_root = database_path.joinpath('rirs/')
    setups = load_json(database_path.joinpath('setups.json'))
    simulation_descriptions = \
        load_json(database_path.joinpath('simulation_descriptions.json'))
    for scenario in simulation_descriptions.values():
        for example_id, example in scenario.items():
            for node_id, sro in example['sro'].items():
                if isinstance(sro, str):
                    example['sro'][node_id] = database_path.joinpath(sro)
            example['node_position'] = setups[example_id]['node_position']
            example['node_orientation'] = \
                setups[example_id]['node_orientation']
            example['environment'] = setups[example_id]['environment']
            example['src_diary'] = [
                complete_source_information(source, example_id, setups,
                                            rir_root)
                for source in example['src_diary']
            ]
    db = {'datasets': simulation_descriptions}
    dump_json(db, json_path, sort_keys=False)
Example #4
def main(_run, exp_dir, storage_dir, database_json, ckpt_name, num_workers,
         batch_size, max_padding_rate, device):
    commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir,
                                   consider_mpi=True,
                                   checkpoint_name=ckpt_name)
    model.to(device)
    model.eval()

    _, validation_data, test_data = get_datasets(
        database_json=database_json,
        min_signal_length=1.5,
        audio_reader=config['audio_reader'],
        stft=config['stft'],
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        storage_dir=exp_dir,
    )

    outputs = []
    with torch.no_grad():
        for example in tqdm(validation_data):
            example = model.example_to_device(example, device)
            (y, seq_len), _ = model(example)
            y = Mean(axis=-1)(y, seq_len)
            outputs.append((
                y.cpu().detach().numpy(),
                example['events'].cpu().detach().numpy(),
            ))

    scores, targets = list(zip(*outputs))
    scores = np.concatenate(scores)
    targets = np.concatenate(targets)
    thresholds, f1 = instance_based.get_optimal_thresholds(targets,
                                                           scores,
                                                           metric='f1')
    decisions = scores > thresholds
    f1, p, r = instance_based.fscore(targets, decisions, event_wise=True)
    ap = metrics.average_precision_score(targets, scores, average=None)
    auc = metrics.roc_auc_score(targets, scores, average=None)
    pos_class_indices, precision_at_hits = instance_based.positive_class_precisions(
        targets, scores)
    lwlrap, per_class_lwlrap, weight_per_class = instance_based.lwlrap_from_precisions(
        precision_at_hits, pos_class_indices, num_classes=targets.shape[1])
    overall_results = {
        'validation': {
            'mF1': np.mean(f1),
            'mP': np.mean(p),
            'mR': np.mean(r),
            'mAP': np.mean(ap),
            'mAUC': np.mean(auc),
            'lwlrap': lwlrap,
        }
    }
    event_validation_results = {}
    labels = load_json(exp_dir / 'events.json')
    for i, label in enumerate(labels):
        event_validation_results[label] = {
            'F1': f1[i],
            'P': p[i],
            'R': r[i],
            'AP': ap[i],
            'AUC': auc[i],
            'lwlrap': per_class_lwlrap[i],
        }

    outputs = []
    with torch.no_grad():
        for example in tqdm(test_data):
            example = model.example_to_device(example, device)
            (y, seq_len), _ = model(example)
            y = Mean(axis=-1)(y, seq_len)
            outputs.append((
                example['example_id'],
                y.cpu().detach().numpy(),
                example['events'].cpu().detach().numpy(),
            ))

    example_ids, scores, targets = list(zip(*outputs))
    example_ids = np.concatenate(example_ids).tolist()
    scores = np.concatenate(scores)
    targets = np.concatenate(targets)
    decisions = scores > thresholds
    f1, p, r = instance_based.fscore(targets, decisions, event_wise=True)
    ap = metrics.average_precision_score(targets, scores, average=None)
    auc = metrics.roc_auc_score(targets, scores, average=None)
    pos_class_indices, precision_at_hits = instance_based.positive_class_precisions(
        targets, scores)
    lwlrap, per_class_lwlrap, weight_per_class = instance_based.lwlrap_from_precisions(
        precision_at_hits, pos_class_indices, num_classes=targets.shape[1])
    overall_results['test'] = {
        'mF1': np.mean(f1),
        'mP': np.mean(p),
        'mR': np.mean(r),
        'mAP': np.mean(ap),
        'mAUC': np.mean(auc),
        'lwlrap': lwlrap,
    }
    dump_json(overall_results,
              storage_dir / 'overall.json',
              indent=4,
              sort_keys=False)
    event_results = {}
    for i, label in sorted(enumerate(labels),
                           key=lambda x: ap[x[0]],
                           reverse=True):
        event_results[label] = {
            'validation': event_validation_results[label],
            'test': {
                'F1': f1[i],
                'P': p[i],
                'R': r[i],
                'AP': ap[i],
                'AUC': auc[i],
                'lwlrap': per_class_lwlrap[i],
            },
        }
    dump_json(event_results,
              storage_dir / 'event_wise.json',
              indent=4,
              sort_keys=False)
    fp = np.argwhere(decisions * (1 - targets))
    dump_json(sorted([(example_ids[n], labels[i]) for n, i in fp]),
              storage_dir / 'fp.json',
              indent=4,
              sort_keys=False)
    fn = np.argwhere((1 - decisions) * targets)
    dump_json(sorted([(example_ids[n], labels[i]) for n, i in fn]),
              storage_dir / 'fn.json',
              indent=4,
              sort_keys=False)
    pprint(overall_results)
Example #5
def main(_run, exp_dir, storage_dir, database_json, test_set, max_examples,
         device):
    if IS_MASTER:
        commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)
    audio_dir = storage_dir / 'audio'
    audio_dir.mkdir(parents=True)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir, consider_mpi=True)
    model.to(device)
    model.eval()

    db = JsonDatabase(database_json)
    test_data = db.get_dataset(test_set)
    if max_examples is not None:
        test_data = test_data.shuffle(
            rng=np.random.RandomState(0))[:max_examples]
    test_data = prepare_dataset(test_data,
                                audio_reader=config['audio_reader'],
                                stft=config['stft'],
                                max_length=None,
                                batch_size=1,
                                shuffle=True)
    squared_err = list()
    with torch.no_grad():
        for example in split_managed(test_data,
                                     is_indexable=False,
                                     progress_bar=True,
                                     allow_single_worker=True):
            example = model.example_to_device(example, device)
            target = example['audio_data'].squeeze(1)
            x = model.feature_extraction(example['stft'], example['seq_len'])
            x = model.wavenet.infer(
                x.squeeze(1),
                chunk_length=80_000,
                chunk_overlap=16_000,
            )
            assert target.shape == x.shape, (target.shape, x.shape)
            squared_err.extend([
                (ex_id, err.cpu().detach().numpy(), x.shape[1])
                for ex_id, err in zip(example['example_id'],
                                      ((x - target) ** 2).sum(1))
            ])

    squared_err_list = COMM.gather(squared_err, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(squared_err_list): {len(squared_err_list)}')
        squared_err = []
        for i in range(len(squared_err_list)):
            squared_err.extend(squared_err_list[i])
        _, err, t = list(zip(*squared_err))
        print('rmse:', np.sqrt(np.sum(err) / np.sum(t)))
        rmse = sorted([(ex_id, np.sqrt(err / t))
                       for ex_id, err, t in squared_err],
                      key=lambda x: x[1])
        dump_json(rmse, storage_dir / 'rmse.json', indent=4, sort_keys=False)
        ex_ids_ordered = [x[0] for x in rmse]
        selected_ids = ex_ids_ordered[:10] + ex_ids_ordered[-10:]
        test_data = db.get_dataset('test_clean').shuffle(
            rng=np.random.RandomState(0))[:max_examples].filter(
                lambda x: x['example_id'] in selected_ids, lazy=False)
        test_data = prepare_dataset(test_data,
                                    audio_reader=config['audio_reader'],
                                    stft=config['stft'],
                                    max_length=10.,
                                    batch_size=1,
                                    shuffle=True)
        with torch.no_grad():
            for example in test_data:
                example = model.example_to_device(example, device)
                x = model.feature_extraction(example['stft'],
                                             example['seq_len'])
                x = model.wavenet.infer(
                    x.squeeze(1),
                    chunk_length=80_000,
                    chunk_overlap=16_000,
                )
                for i, audio in enumerate(x.cpu().detach().numpy()):
                    wavfile.write(
                        str(audio_dir / f'{example["example_id"][i]}.wav'),
                        model.sample_rate, audio)
Example #6
def dump(
    obj,
    path,
    mkdir=False,
    mkdir_parents=False,
    mkdir_exist_ok=False,  # Should this be an option? Should the default be True?
    unsafe=False,  # Should this be an option? Should the default be True?
    # atomic=False,  ToDo: Add atomic support
    **kwargs,
):
    """
    A generic dump function to write the obj to path.

    Infer the dump protocol (e.g. json, pickle, ...) from the path name.

    Supported formats:
     - Text:
       - json
       - yaml
     - Binary:
       - pkl: pickle
       - dill
       - h5: HDF5
       - wav
       - mat: MATLAB
       - npy: Numpy
       - npz: Numpy compressed
       - pth: Pickle with Pytorch support
     - Compressed:
       - json.gz
       - pkl.gz
       - npy.gz

    Args:
        obj: Arbitrary object that is supported by the dump protocol.
        path: str or pathlib.Path
        mkdir:
            Whether to call mkdir if the parent dir of path does not exist.
        mkdir_parents:
            Forwarded to pathlib.Path.mkdir as `parents`.
        mkdir_exist_ok:
            Forwarded to pathlib.Path.mkdir as `exist_ok`.
        unsafe:
            Allow unsafe dump protocols (e.g. pickle). This option is more
            relevant for load.
        **kwargs:
            Forwarded to the particular dump function.
            Should rarely be used; if a special property of a dump
            function/protocol is needed, call that dump function directly.

    Returns:

    """
    path = normalize_path(path, allow_fd=False)
    if mkdir:
        if mkdir_exist_ok:
            # Assume that in most cases the dir exists.
            # -> try first to reduce io requests
            try:
                return dump(obj, path, unsafe=unsafe, **kwargs)
            except FileNotFoundError:
                pass
        path.parent.mkdir(parents=mkdir_parents, exist_ok=mkdir_exist_ok)

    if str(path).endswith(".json"):
        from paderbox.io import dump_json
        dump_json(obj, path, **kwargs)
    elif str(path).endswith(".pkl"):
        assert unsafe, (unsafe, path)
        with path.open("wb") as fp:
            pickle.dump(obj, fp, protocol=pickle.HIGHEST_PROTOCOL, **kwargs)
    elif str(path).endswith(".dill"):
        assert unsafe, (unsafe, path)
        with path.open("wb") as fp:
            import dill
            dill.dump(obj, fp, **kwargs)
    elif str(path).endswith(".h5"):
        from paderbox.io.hdf5 import dump_hdf5
        dump_hdf5(obj, path, **kwargs)
    elif str(path).endswith(".yaml"):
        if unsafe:
            from paderbox.io.yaml_module import dump_yaml_unsafe
            dump_yaml_unsafe(obj, path, **kwargs)
        else:
            from paderbox.io.yaml_module import dump_yaml
            dump_yaml(obj, path, **kwargs)
    elif str(path).endswith(".gz"):
        assert len(kwargs) == 0, kwargs
        with gzip.GzipFile(path, 'wb', compresslevel=1) as f:
            if str(path).endswith(".json.gz"):
                f.write(json.dumps(obj).encode())
            elif str(path).endswith(".pkl.gz"):
                assert unsafe, (unsafe, path)
                pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
            elif str(path).endswith(".npy.gz"):
                np.save(f, obj, allow_pickle=unsafe)
            else:
                raise ValueError(path)
    elif str(path).endswith(".wav"):
        from paderbox.io import dump_audio
        if np.ndim(obj) == 1:
            pass
        elif np.ndim(obj) == 2:
            assert np.shape(obj)[0] < 20, (np.shape(obj), obj)
        else:
            raise AssertionError(('Expect ndim in [1, 2]', np.shape(obj), obj))
        with path.open("wb") as fp:  # Throws better exception msg
            dump_audio(obj, fp, **kwargs)
    elif str(path).endswith('.mat'):
        import scipy.io as sio
        sio.savemat(path, obj, **kwargs)
    elif str(path).endswith('.npy'):
        np.save(str(path), obj, allow_pickle=unsafe, **kwargs)
    elif str(path).endswith('.npz'):
        assert unsafe, (unsafe, path)
        assert len(kwargs) == 0, kwargs
        if isinstance(obj, dict):
            np.savez(str(path), **obj)
        else:
            np.savez(str(path), obj)
    elif str(path).endswith('.pth'):
        assert unsafe, (unsafe, path)
        import torch
        torch.save(obj, str(path), **kwargs)
    else:
        raise ValueError('Unsupported suffix:', path)
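
A minimal usage sketch for the generic dump function above (the file names are hypothetical and `dump` is assumed to be importable in the current scope):

import numpy as np

results = {'accuracy': 0.93, 'loss': 0.21}

# '.json' dispatches to paderbox.io.dump_json; the mkdir flags create a
# missing parent directory first.
dump(results, '/tmp/demo/results.json',
     mkdir=True, mkdir_parents=True, mkdir_exist_ok=True)

# Binary pickle ('.pkl') is only written when unsafe=True.
dump(results, '/tmp/demo/results.pkl', unsafe=True,
     mkdir=True, mkdir_parents=True, mkdir_exist_ok=True)

# Numpy arrays go to '.npy'; allow_pickle follows the unsafe flag.
dump(np.zeros(16), '/tmp/demo/features.npy',
     mkdir=True, mkdir_parents=True, mkdir_exist_ok=True)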
Example #7
def main(_run, model_path, load_ckpt, batch_size, device, store_misclassified):
    if IS_MASTER:
        commands.print_config(_run)

    model_path = Path(model_path)
    eval_dir = get_new_subdir(model_path / 'eval',
                              id_naming='time',
                              consider_mpi=True)
    # perform evaluation on a sub-set (10%) of the dataset used for training
    config = load_json(model_path / 'config.json')
    database_json = config['database_json']
    dataset = config['dataset']

    model = pt.Model.from_storage_dir(model_path,
                                      checkpoint_name=load_ckpt,
                                      consider_mpi=True)
    model.to(device)
    # Turn on evaluation mode for, e.g., BatchNorm and Dropout modules
    model.eval()

    _, _, test_set = get_datasets(model_path,
                                  database_json,
                                  dataset,
                                  batch_size,
                                  return_indexable=device == 'cpu')
    with torch.no_grad():
        summary = dict(misclassified_examples=dict(),
                       correct_classified_examples=dict(),
                       hits=list())
        for batch in split_managed(test_set,
                                   is_indexable=device == 'cpu',
                                   progress_bar=True,
                                   allow_single_worker=True):
            output = model(pt.data.example_to_device(batch, device))
            prediction = torch.argmax(output, dim=-1).cpu().numpy()
            confidence = torch.softmax(output, dim=-1).max(
                dim=-1).values.cpu().numpy()
            label = np.array(batch['speaker_id'])
            hits = (label == prediction).astype('bool')
            summary['hits'].extend(hits.tolist())
            summary['misclassified_examples'].update({
                k: {
                    'true_label': v1,
                    'predicted_label': v2,
                    'audio_path': v3,
                    'confidence': f'{v4:.2%}',
                }
                for k, v1, v2, v3, v4 in zip(
                    np.array(batch['example_id'])[~hits], label[~hits],
                    prediction[~hits],
                    np.array(batch['audio_path'])[~hits], confidence[~hits])
            })
            # For each correctly predicted label, collect the audio paths.
            correct_classified = summary['correct_classified_examples']
            summary['correct_classified_examples'].update({
                k: (correct_classified[k] + [v]
                    if k in correct_classified else [v])
                for k, v in zip(prediction[hits],
                                np.array(batch['audio_path'])[hits])
            })

    summary_list = COMM.gather(summary, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(summary_list): {len(summary_list)}')
        if len(summary_list) > 1:
            summary = dict(
                misclassified_examples=dict(),
                correct_classified_examples=dict(),
                hits=list(),
            )
            for partial_summary in summary_list:
                summary['hits'].extend(partial_summary['hits'])
                summary['misclassified_examples'].update(
                    partial_summary['misclassified_examples'])
                for label, audio_path_list in \
                        partial_summary['correct_classified_examples'].items():
                    correct = summary['correct_classified_examples']
                    correct[label] = (
                        correct[label] + audio_path_list
                        if label in correct else audio_path_list
                    )
        hits = summary['hits']
        misclassified_examples = summary['misclassified_examples']
        correct_classified_examples = summary['correct_classified_examples']
        accuracy = np.array(hits).astype('float').mean()
        if store_misclassified:
            misclassified_dir = eval_dir / 'misclassified_examples'
            for example_id, v in misclassified_examples.items():
                label, prediction_label, audio_path, _ = v.values()
                try:
                    predicted_speaker_audio_path = \
                        correct_classified_examples[prediction_label][0]
                    example_dir = \
                        misclassified_dir / f'{example_id}_{label}_{prediction_label}'
                    example_dir.mkdir(parents=True)
                    os.symlink(audio_path, example_dir / 'example.wav')
                    os.symlink(predicted_speaker_audio_path,
                               example_dir / 'predicted_speaker_example.wav')
                except KeyError:
                    warnings.warn(
                        'There were no correctly predicted inputs from speaker '
                        f'with speaker label {prediction_label}')
        outputs = dict(
            accuracy=f'{accuracy:.2%} ({np.sum(hits)}/{len(hits)})',
            misclassifications=misclassified_examples,
        )
        print(f'Speaker classification accuracy on test set: {accuracy:.2%}')
        dump_json(outputs, eval_dir / 'results.json')
        print(f'Wrote results to {eval_dir / "results.json"}')