Example #1
    def __init__(
        self,
        json_path,
        target,
        dset,
        sample_rate=8000,
        single_channel=True,
        segment=4.0,
        nondefault_nsrc=None,
        normalize_audio=False,
    ):
        try:
            import sms_wsj  # noqa
        except ModuleNotFoundError:
            import warnings

            warnings.warn(
                "Some of the functionality relies on the sms_wsj package "
                "downloadable from https://github.com/fgnt/sms_wsj . "
                "The user is encouraged to install the package.")
        super().__init__()
        if target not in SMS_TARGETS.keys():
            raise ValueError("Unexpected task {}, expected one of "
                             "{}".format(target, SMS_TARGETS.keys()))

        # Task setting
        self.json_path = json_path
        self.target = target
        self.target_dict = SMS_TARGETS[target]
        self.single_channel = single_channel
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio
        self.seg_len = None if segment is None else int(segment * sample_rate)
        if not nondefault_nsrc:
            self.n_src = self.target_dict["default_nsrc"]
        else:
            assert nondefault_nsrc >= self.target_dict["default_nsrc"]
            self.n_src = nondefault_nsrc
        self.like_test = self.seg_len is None
        self.dset = dset
        self.EPS = 1e-8

        # Load json files

        from lazy_dataset.database import JsonDatabase

        db = JsonDatabase(json_path)
        dataset = db.get_dataset(dset)
        # Filter out short utterances only when segment is specified
        if not self.like_test:

            def filter_short_examples(example):
                num_samples = example["num_samples"]["observation"]
                return num_samples >= self.seg_len

            dataset = dataset.filter(filter_short_examples, lazy=False)
        self.dataset = dataset
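
JsonDatabase comes from the lazy_dataset package and reads a JSON file with a top-level "datasets" mapping from dataset name to example_id -> example dict (cf. the create_json helper in Example #13). Below is a minimal, self-contained sketch of that layout and of the short-utterance filter used above; the file contents are hypothetical:

import json
import tempfile
from pathlib import Path

from lazy_dataset.database import JsonDatabase

with tempfile.TemporaryDirectory() as tmpdir:
    json_path = Path(tmpdir) / 'db.json'
    json_path.write_text(json.dumps({
        'datasets': {
            'train': {
                'ex1': {'num_samples': {'observation': 32000}},
                'ex2': {'num_samples': {'observation': 8000}},
            }
        }
    }))
    db = JsonDatabase(json_path)
    ds = db.get_dataset('train')
    # Keep only utterances of at least 4 s at 8 kHz, i.e. 32000 samples
    ds = ds.filter(
        lambda ex: ex['num_samples']['observation'] >= 32000, lazy=False)
    assert len(ds) == 1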
Example #2
def main(_run, _log, trainer, database_json, training_sets, validation_sets,
         audio_reader, stft, max_length_in_sec, batch_size, resume):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(_run.config,
                         _log,
                         config_filename=str(storage_dir / 'config.json'))

    db = JsonDatabase(database_json)
    training_data = db.get_dataset(training_sets)
    validation_data = db.get_dataset(validation_sets)
    training_data = prepare_dataset(training_data,
                                    audio_reader=audio_reader,
                                    stft=stft,
                                    max_length_in_sec=max_length_in_sec,
                                    batch_size=batch_size,
                                    shuffle=True)
    validation_data = prepare_dataset(validation_data,
                                      audio_reader=audio_reader,
                                      stft=stft,
                                      max_length_in_sec=max_length_in_sec,
                                      batch_size=batch_size,
                                      shuffle=False)

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data)
    trainer.train(training_data, resume=resume)
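
The prepare_dataset used here is recipe-specific. A hedged sketch of what such a function typically chains, assuming audio_reader and stft behave as per-example callables; only map/filter/shuffle/batch/prefetch are the actual lazy_dataset API, the key name and the 8 kHz sample rate are assumptions:

def prepare_dataset(dataset, audio_reader, stft, max_length_in_sec,
                    batch_size, shuffle=False):
    if max_length_in_sec is not None:
        # Filter on metadata before loading any audio.
        dataset = dataset.filter(
            lambda ex: ex['num_samples'] <= max_length_in_sec * 8000,
            lazy=False)
    dataset = dataset.map(audio_reader)  # load waveforms into the example
    dataset = dataset.map(stft)          # compute spectral features
    if shuffle:
        dataset = dataset.shuffle(reshuffle=True)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(num_workers=4, buffer_size=8)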
Example #3
def get_datasets(storage_dir,
                 database_json,
                 dataset,
                 batch_size=16,
                 return_indexable=False):
    db = JsonDatabase(database_json)
    ds = db.get_dataset(dataset)

    def prepare_example(example):
        example['audio_path'] = example['audio_path']['observation']
        example['speaker_id'] = example['speaker_id'].split('-')[0]
        return example

    ds = ds.map(prepare_example)

    speaker_encoder = LabelEncoder(label_key='speaker_id',
                                   storage_dir=storage_dir,
                                   to_array=True)
    speaker_encoder.initialize_labels(dataset=ds, verbose=True)
    ds = ds.map(speaker_encoder)

    # LibriSpeech (the default database) does not share speakers across
    # different datasets, i.e., the datasets, e.g. clean_100 and dev_clean, have
    # a different set of non-overlapping speakers. To guarantee the same set of
    # speakers during training, validation and evaluation, we perform a split of
    # the train set, e.g., clean_100 or clean_360.
    train_set, validate_set, test_set = train_test_split(ds)

    training_data = prepare_dataset(train_set, batch_size, training=True)
    validation_data = prepare_dataset(validate_set, batch_size, training=False)
    test_data = prepare_dataset(test_set,
                                batch_size,
                                training=False,
                                return_indexable=return_indexable)
    return training_data, validation_data, test_data
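
train_test_split here partitions a single dataset so that all three splits draw from the same speaker pool (see the comment above). A plausible sketch, not the recipe's actual implementation; it assumes an indexable lazy_dataset (list indexing as used in Example #24), and the ratios and seed are made up:

import numpy as np

def train_test_split(ds, ratios=(0.8, 0.1, 0.1), seed=0):
    # Shuffle indices deterministically, then cut into three partitions.
    indices = np.random.RandomState(seed).permutation(len(ds)).tolist()
    n_train = int(ratios[0] * len(ds))
    n_eval = int(ratios[1] * len(ds))
    train = ds[sorted(indices[:n_train])]
    validate = ds[sorted(indices[n_train:n_train + n_eval])]
    test = ds[sorted(indices[n_train + n_eval:])]
    return train, validate, test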
Example #4
def prepare_and_train(_run, _log, trainer, train_dataset, validate_dataset,
                      lr_scheduler_step, lr_scheduler_gamma, load_model_from,
                      database_json):
    trainer = pt.Trainer.from_config(trainer)
    checkpoint_path = trainer.checkpoint_dir / 'ckpt_latest.pth'

    if load_model_from is not None and not checkpoint_path.is_file():
        _log.info(f'Loading model weights from {load_model_from}')
        checkpoint = torch.load(load_model_from)
        trainer.model.load_state_dict(checkpoint['model'])

    db = JsonDatabase(database_json)

    # Perform a test run to check if everything works
    trainer.test_run(
        prepare_iterable_captured(db, train_dataset),
        prepare_iterable_captured(db, validate_dataset),
    )

    # Register hooks and start the actual training
    trainer.register_validation_hook(
        prepare_iterable_captured(db, validate_dataset))

    # Learning rate scheduler
    trainer.register_hook(
        pt.train.hooks.LRSchedulerHook(
            torch.optim.lr_scheduler.StepLR(
                trainer.optimizer.optimizer,
                step_size=lr_scheduler_step,
                gamma=lr_scheduler_gamma,
            )))

    trainer.train(prepare_iterable_captured(db, train_dataset),
                  resume=checkpoint_path.is_file())
Example #5
def main(_run, batch_size, datasets, debug, experiment_dir, database_json):
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            iterable = prepare_iterable(
                db,
                dataset,
                batch_size,
                return_keys=None,
                prefetch=False,
            )

            for batch in dlp_mpi.split_managed(iterable,
                                               is_indexable=False,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(pt.data.example_to_device(batch))

                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]
                entry['metrics'] \
                    = pb_bss.evaluation.OutputMetrics(speech_prediction=z,
                                                      speech_source=s).as_dict()

                summary[dataset][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'len(summary_list): {len(summary_list)}')
        for partial_summary in summary_list:
            for dataset, values in partial_summary.items():
                summary[dataset].update(values)

        for dataset, values in summary.items():
            print(f'{dataset}: {len(values)}')

        result_json_path = experiment_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
Example #6
def prepare_and_train(_run, _log, trainer, train_datasets, validate_datasets,
                      lr_scheduler_step, lr_scheduler_gamma, load_model_from,
                      database_jsons):
    trainer = get_trainer(trainer, load_model_from)

    if isinstance(database_jsons, str):
        database_jsons = database_jsons.split(',')

    db = JsonDatabase(database_jsons)

    # Perform a test run to check if everything works
    trainer.test_run(
        prepare_iterable_captured(db, train_datasets),
        prepare_iterable_captured(db, validate_datasets),
    )

    # Register hooks and start the actual training
    trainer.register_validation_hook(
        prepare_iterable_captured(db, validate_datasets))

    # Learning rate scheduler
    trainer.register_hook(
        pt.train.hooks.LRSchedulerHook(
            torch.optim.lr_scheduler.StepLR(
                trainer.optimizer.optimizer,
                step_size=lr_scheduler_step,
                gamma=lr_scheduler_gamma,
            )))

    trainer.train(prepare_iterable_captured(db, train_datasets),
                  resume=trainer.checkpoint_dir.exists())
Example #7
def test_run(_config, _run, train_dataset, validate_dataset, database_json):
    sacred.commands.print_config(_run)
    trainer = pt.Trainer.from_config(_config["trainer"])

    db = JsonDatabase(json_path=database_json)

    trainer.test_run(
        prepare_dataset_captured(db, train_dataset),
        prepare_dataset_captured(db, validate_dataset),
    )
Example #8
def test_run(_run, _log, trainer, train_dataset, validate_dataset,
             load_model_from, database_json):
    trainer = get_trainer(trainer, load_model_from, _log)

    db = JsonDatabase(database_json)

    # Perform a test run to check if everything works
    trainer.test_run(
        prepare_dataset_captured(db, train_dataset, shuffle=True),
        prepare_dataset_captured(db, validate_dataset, shuffle=True),
    )
Example #9
def test_run(storage_dir, database_json):
    model = SimpleMaskEstimator(513)
    print(f'Simple training for the following model: {model}')
    database = JsonDatabase(database_json)
    train_dataset = get_train_dataset(database)
    validation_dataset = get_validation_dataset(database)
    trainer = pt.train.trainer.Trainer(model,
                                       storage_dir,
                                       optimizer=pt.train.optimizer.Adam(),
                                       stop_trigger=(int(1e5), 'iteration'))
    trainer.test_run(train_dataset, validation_dataset)
Example #10
    def test_json_database_multiple(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = Path(tmpdir)

            d1_path = tmpdir / 'd1.json'
            with d1_path.open('w') as fd:
                json.dump(self.json, fd)

            d2_path = tmpdir / 'd2.json'
            with d2_path.open('w') as fd:
                json.dump(self.json2, fd)

            db = JsonDatabase(d1_path, d2_path)
            assert len(db.dataset_names) == 5

            db = JsonDatabase([d1_path, d2_path])
            assert len(db.dataset_names) == 5

            db = JsonDatabase(json_path=d1_path)
            assert len(db.dataset_names) == 3

            db = JsonDatabase(json_path=(d1_path, d2_path))
            assert len(db.dataset_names) == 5

            with pytest.raises(AssertionError):
                # Test metadata check
                _ = JsonDatabase(d2_path, d1_path).dataset_names

            with pytest.raises(AssertionError):
                # Test duplicate dataset name check
                _ = JsonDatabase(d1_path, d1_path).dataset_names

            with pytest.raises(AssertionError):
                DictDatabase([d1_path], d2_path)
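
Assumed contents consistent with the assertions above: self.json defines three datasets and self.json2 two more with distinct names, so the merged database exposes 3 + 2 = 5 dataset names, while passing the same file twice trips the duplicate-name check. A sketch of such contents:

json1 = {'datasets': {'train': {'ex1': {}},
                      'dev': {'ex2': {}},
                      'test': {'ex3': {}}}}
json2 = {'datasets': {'eval_a': {'ex4': {}},
                      'eval_b': {'ex5': {}}}}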
Example #11
def test_run(_run, _log, trainer, train_datasets, validate_datasets,
             load_model_from, database_jsons):
    trainer = get_trainer(trainer, load_model_from)

    if isinstance(database_jsons, str):
        database_jsons = database_jsons.split(',')

    db = JsonDatabase(database_jsons)

    # Perform a test run to check if everything works
    trainer.test_run(
        prepare_iterable_captured(db, train_datasets),
        prepare_iterable_captured(db, validate_datasets),
    )
Example #12
def train(storage_dir, database_json):
    model = SimpleMaskEstimator(513)
    print(f'Simple training for the following model: {model}')
    database = JsonDatabase(database_json)
    train_dataset = get_train_dataset(database)
    validation_dataset = get_validation_dataset(database)
    trainer = pt.Trainer(model,
                         storage_dir,
                         optimizer=pt.train.optimizer.Adam(),
                         stop_trigger=(int(1e5), 'iteration'))
    trainer.test_run(train_dataset, validation_dataset)
    trainer.register_validation_hook(validation_dataset,
                                     n_back_off=5,
                                     lr_update_factor=1 / 10,
                                     back_off_patience=1,
                                     early_stopping_patience=None)
    trainer.train(train_dataset)
Example #13
def create_json(db_dir, intermediate_json_path, write_all, snr_range=(20, 30)):
    db = JsonDatabase(intermediate_json_path)
    json_dict = dict(datasets=dict())
    database_dict = db.data['datasets']

    if write_all:
        key_mapper = KEY_MAPPER
    else:
        key_mapper = {'observation': 'observation'}

    for dataset_name, dataset in database_dict.items():
        dataset_dict = dict()
        for ex_id, ex in dataset.items():
            for key, data_type in key_mapper.items():
                current_path = db_dir / data_type / dataset_name
                if key in ['observation', 'noise_image']:
                    ex['audio_path'][key] = str(current_path / f'{ex_id}.wav')
                else:
                    ex['audio_path'][key] = [
                        str(current_path / f'{ex_id}_{k}.wav')
                        for k in range(len(ex['speaker_id']))
                    ]

            ex['audio_path']['speech_source'] = [
                # .../sms_wsj/cache/wsj_8k_zeromean/13-11.1/wsj1/si_tr_s/4ax/4axc0218.wav
                str(db_dir.joinpath(*Path(rir).parts[-6:]))
                for rir in ex['audio_path']['speech_source']
            ]

            ex['audio_path']['rir'] = [
                # .../sms_wsj/cache/rirs/train_si284/0/h_0.wav
                str(db_dir.joinpath(*Path(rir).parts[-4:]))
                for rir in ex['audio_path']['rir']
            ]

            rng = _example_id_to_rng(ex_id)
            snr = rng.uniform(*snr_range)
            if 'dataset' in ex:
                del ex['dataset']
            ex["snr"] = snr
            dataset_dict[ex_id] = ex
        json_dict['datasets'][dataset_name] = dataset_dict
    return json_dict
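
_example_id_to_rng seeds a NumPy RandomState from the example id, so every example always draws the same SNR and the resulting json is reproducible. A plausible sketch (the actual helper lives in sms_wsj and may differ):

import hashlib

import numpy as np

def _example_id_to_rng(example_id):
    # Derive a stable 32-bit seed from the example id.
    seed = int.from_bytes(
        hashlib.md5(example_id.encode()).digest()[:4], 'little')
    return np.random.RandomState(seed)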
Example #14
def prepare_and_train(_config, _run, train_dataset, validate_dataset,
                      database_json):
    """ Prepares the train and validation dataset from the database object """

    sacred.commands.print_config(_run)
    trainer = pt.Trainer.from_config(_config["trainer"])
    checkpoint_path = trainer.checkpoint_dir / 'ckpt_latest.pth'

    db = JsonDatabase(json_path=database_json)
    print(repr(train_dataset), repr(validate_dataset))

    trainer.test_run(
        prepare_dataset_captured(db, train_dataset),
        prepare_dataset_captured(db, validate_dataset),
    )
    trainer.register_validation_hook(
        prepare_dataset_captured(db, validate_dataset))
    trainer.train(prepare_dataset_captured(db, train_dataset),
                  resume=checkpoint_path.is_file())
Example #15
def prepare_and_train(_run, _log, trainer, train_dataset, validate_dataset,
                      lr_scheduler_step, lr_scheduler_gamma,
                      load_model_from, database_json):
    trainer = get_trainer(trainer, load_model_from)

    db = JsonDatabase(database_json)

    # Perform a test run to check if everything works
    trainer.test_run(
        prepare_iterable_captured(db, train_dataset),
        prepare_iterable_captured(db, validate_dataset),
    )

    # Register hooks and start the actual training

    # Learning rate scheduler
    if lr_scheduler_step:
        trainer.register_hook(pt.train.hooks.LRSchedulerHook(
            torch.optim.lr_scheduler.StepLR(
                trainer.optimizer.optimizer,
                step_size=lr_scheduler_step,
                gamma=lr_scheduler_gamma,
            )
        ))

        # Don't use LR back-off
        trainer.register_validation_hook(
            prepare_iterable_captured(db, validate_dataset),
        )
    else:
        # Use LR back-off
        trainer.register_validation_hook(
            prepare_iterable_captured(db, validate_dataset),
            n_back_off=5, back_off_patience=3
        )

    trainer.train(
        prepare_iterable_captured(db, train_dataset),
        resume=trainer.checkpoint_dir.exists()
    )
Example #16
def main(dst_dir, json_path, write_all, new_json_path, snr_range):
    json_path = Path(json_path).expanduser().resolve()
    dst_dir = Path(dst_dir).expanduser().resolve()
    if dlp_mpi.IS_MASTER:
        assert json_path.exists(), json_path
        dst_dir.mkdir(exist_ok=True, parents=True)
        if not any([(dst_dir / data_type).exists()
                    for data_type in type_mapper.keys()]):
            write_files = True
        else:
            write_files = False
            num_wav_files = len(check_files(dst_dir))
            if write_all and num_wav_files == (2 * 2 + 2) * 32000:
                print('Wav files seem to exist. They are not overwritten.')
            elif not write_all and num_wav_files == 32000 and (
                    dst_dir / 'observation').exists():
                print('Wav files seem to exist. They are not overwritten.')
            else:
                raise ValueError(
                    'Not all wav files exist. However, the directory structure'
                    ' already exists.')
    else:
        write_files = None
    write_files = dlp_mpi.COMM.bcast(write_files, root=dlp_mpi.MASTER)
    db = JsonDatabase(json_path)
    if write_files:
        write_wavs(dst_dir, db, write_all=write_all, snr_range=snr_range)

    if dlp_mpi.IS_MASTER and new_json_path:
        print(f'Creating a new json and saving it to {new_json_path}')
        new_json_path = Path(new_json_path).expanduser().resolve()
        updated_json = create_json(dst_dir, db, write_all, snr_range=snr_range)
        new_json_path.parent.mkdir(exist_ok=True, parents=True)
        with new_json_path.open('w') as f:
            json.dump(updated_json, f, indent=4, ensure_ascii=False)
        print(f'{new_json_path} written')
Example #17
def run(_config, egs_path, json_path, stage, end_stage, gmm_dir, ali_data_type,
        train_data_type, target_speaker, channels, kaldi_cmd, num_jobs):
    sms_db = JsonDatabase(json_path)
    sms_kaldi_dir = Path(egs_path).resolve().expanduser()
    sms_kaldi_dir = sms_kaldi_dir / train_data_type / 's5'
    if stage <= 1 < end_stage:
        create_kaldi_dir(sms_kaldi_dir)

    if kaldi_cmd == 'ssh.pl':
        if 'CCS_NODEFILE' in os.environ:
            pc2_environ(sms_kaldi_dir)
        with (sms_kaldi_dir / 'cmd.sh').open('a') as fd:
            fd.write('export train_cmd="ssh.pl"\n')
    elif kaldi_cmd == 'run.pl':
        with (sms_kaldi_dir / 'cmd.sh').open('a') as fd:
            fd.write('export train_cmd="run.pl"\n')
    else:
        raise ValueError(kaldi_cmd)

    if gmm_dir is None:
        gmm = 'tri4b'
    else:
        gmm_dir = Path(gmm_dir)
        gmm = gmm_dir.name
    if stage <= 2 < end_stage:
        if gmm_dir is None:
            create_data_dir(sms_kaldi_dir,
                            db=sms_db,
                            data_type='wsj_8k',
                            target_speaker=target_speaker)
            print('Start training tri3 model on wsj_8k')
            run_process([
                f'{sms_kaldi_dir}/local_sms/get_tri3_model.bash', '--dest_dir',
                f'{sms_kaldi_dir}', '--nj',
                str(num_jobs)
            ],
                        cwd=str(sms_kaldi_dir),
                        stdout=None,
                        stderr=None)
        else:
            assert gmm_dir.exists()
            gmm_parent_dir = sms_kaldi_dir / 'exp' / 'wsj_8k'
            gmm_parent_dir.mkdir(parents=True)
            shutil.copytree(gmm_dir, gmm_parent_dir / gmm)

    if stage <= 3 < end_stage and ali_data_type != train_data_type:
        create_data_dir(sms_kaldi_dir,
                        db=sms_db,
                        data_type=ali_data_type,
                        ref_channels=channels,
                        target_speaker=target_speaker)

    if stage <= 4 < end_stage:
        create_data_dir(sms_kaldi_dir,
                        db=sms_db,
                        data_type=train_data_type,
                        ref_channels=channels,
                        target_speaker=target_speaker)

    if stage <= 16 < end_stage:
        print('Prepare data for nnet3 model training on sms_wsj')
        run_process([
            f'{sms_kaldi_dir}/local_sms/prepare_nnet3_model_training.bash',
            '--dest_dir', f'{sms_kaldi_dir}', '--cv_sets', "cv_dev93",
            '--stage',
            str(stage), '--gmm_data_type', 'wsj_8k', '--gmm', gmm,
            '--ali_data_type', ali_data_type, '--dataset', train_data_type,
            '--nj',
            str(num_jobs)
        ],
                    cwd=str(sms_kaldi_dir),
                    stdout=None,
                    stderr=None)

    if stage <= 20 and end_stage >= 17:
        print('Start training nnet3 model on sms_wsj')
        run_process([
            f'{sms_kaldi_dir}/local_sms/get_nnet3_model.bash', '--dest_dir',
            f'{sms_kaldi_dir}', '--cv_sets', '"cv_dev93"', '--stage',
            str(stage), '--gmm_data_type', 'wsj_8k', '--gmm', gmm,
            '--ali_data_type', ali_data_type, '--dataset', train_data_type,
            '--nj',
            str(num_jobs)
        ],
                    cwd=str(sms_kaldi_dir),
                    stdout=None,
                    stderr=None)
Example #18
def _create_data_dir(
        get_wer_command_fn, kaldi_dir, db=None, json_path=None,
        dataset_names=None, data_type='wsj_8k', target_speaker=0,
        ref_channels=0,
):
    """

    Args:
        get_wer_command_fn:
        kaldi_dir:
        db:
        json_path:
        dataset_names:
        data_type:
        target_speaker:
        ref_channels:

    Returns:

    """

    assert not (db is None and json_path is None), (db, json_path)
    if db is None:
        db = JsonDatabase(json_path)

    kaldi_dir = Path(kaldi_dir).expanduser().resolve()

    data_dir = kaldi_dir / 'data' / data_type
    data_dir.mkdir(exist_ok=True, parents=True)

    if not isinstance(ref_channels, (list, tuple)):
        ref_channels = [ref_channels]
    example_id_to_wav = dict()
    example_id_to_speaker = dict()
    example_id_to_trans = dict()
    example_id_to_duration = dict()
    speaker_to_gender = defaultdict(lambda: defaultdict(list))
    dataset_to_example_id = defaultdict(list)

    if dataset_names is None:
        dataset_names = ('train_si284', 'cv_dev93', 'test_eval92')
    elif isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    if not isinstance(target_speaker, (list, tuple)):
        target_speaker = [target_speaker]
    assert not any([
        (data_dir / dataset_name).exists() for dataset_name in dataset_names
    ]), (
        'One of the following directories already exists: '
        f'{[data_dir / ds_name for ds_name in dataset_names]}\n'
        'Delete them if you want to restart this stage'
    )

    print(
        'Create data dir for '
        f'{", ".join([f"{data_type}/{ds_name}" for ds_name in dataset_names])} '
        'data'
    )

    dataset = db.get_dataset(dataset_names)
    for example in dataset:
        for ref_ch in ref_channels:
            org_example_id = example['example_id']
            dataset_name = example['dataset']
            for t_spk in target_speaker:
                speaker_id = example['speaker_id'][t_spk]
                example_id = speaker_id + '_' + org_example_id
                example_id += f'_c{ref_ch}' if len(ref_channels) > 1 else ''
                example_id_to_wav[example_id] = get_wer_command_fn(
                    example, ref_ch=ref_ch, spk=t_spk)
                try:
                    transcription = example['kaldi_transcription'][t_spk]
                except KeyError:
                    transcription = example['transcription'][t_spk]
                example_id_to_trans[example_id] = transcription

                example_id_to_speaker[example_id] = speaker_id
                gender = example['gender'][t_spk]
                speaker_to_gender[dataset_name][speaker_id] = gender
                if isinstance(example['num_samples'], dict):
                    num_samples = example['num_samples']['observation']
                else:
                    num_samples = example['num_samples']
                example_id_to_duration[
                    example_id] = f"{num_samples / SAMPLE_RATE:.2f}"
                dataset_to_example_id[dataset_name].append(example_id)

    assert len(example_id_to_speaker) > 0, dataset
    for dataset_name in dataset_names:
        path = data_dir / dataset_name
        path.mkdir(exist_ok=False, parents=False)
        for name, dictionary in (
                ("utt2spk", example_id_to_speaker),
                ("text", example_id_to_trans),
                ("utt2dur", example_id_to_duration),
                ("wav.scp", example_id_to_wav)
        ):
            dictionary = {key: value for key, value in dictionary.items()
                          if key in dataset_to_example_id[dataset_name]}

            assert len(dictionary) > 0, (dataset_name, name)
            if name == 'utt2dur':
                dump_keyed_lines(dictionary, path / 'reco2dur')
            dump_keyed_lines(dictionary, path / name)
        dictionary = speaker_to_gender[dataset_name]
        assert len(dictionary) > 0, (dataset_name, name)
        dump_keyed_lines(dictionary, path / 'spk2gender')
        run_process([
            f'utils/fix_data_dir.sh', f'{path}'],
            cwd=str(kaldi_dir), stdout=None, stderr=None
        )
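
The Kaldi files written above (wav.scp, utt2spk, text, utt2dur/reco2dur, spk2gender) are plain "key value" text files with one utterance or speaker per line. dump_keyed_lines comes from sms_wsj; a plausible sketch of its behavior (the real helper may differ):

def dump_keyed_lines(dictionary, path):
    # Kaldi expects the keys (utterance or speaker ids) to be sorted.
    with open(path, 'w') as fd:
        for key in sorted(dictionary):
            fd.write(f'{key} {dictionary[key]}\n')

# e.g. utt2spk: "<speaker_id>_<example_id> <speaker_id>"
#      utt2dur: "<utt_id> 7.34"
#      wav.scp: "<utt_id> <wav path or a command piping wav to stdout>"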
Example #19
def run(_config, _run, audio_dir, kaldi_data_dir, json_path):
    assert Path(kaldi_root).exists(), kaldi_root

    assert len(ex.current_run.observers) == 1, (
        '`FileObserver` missing. Add a `FileObserver` with `-F foo/bar/`.')
    base_dir = Path(ex.current_run.observers[0].basedir)
    base_dir = base_dir.expanduser().resolve()
    if audio_dir is not None:
        audio_dir = Path(audio_dir).expanduser().resolve()
        assert audio_dir.exists(), audio_dir
        json_path = Path(json_path).expanduser().resolve()
        assert json_path.exists(), json_path
        db = JsonDatabase(json_path)
    elif kaldi_data_dir is not None:
        kaldi_data_dir = Path(kaldi_data_dir).expanduser().resolve()
        assert kaldi_data_dir.exists(), kaldi_data_dir
        assert json_path is None, json_path
    elif json_path is not None:
        json_path = Path(json_path).expanduser().resolve()
        assert json_path.exists(), json_path
        db = JsonDatabase(json_path)
    else:
        raise ValueError('Either json_path, audio_dir or kaldi_data_dir has '
                         'to be defined.')
    if _config['model_egs_dir'] is None:
        model_egs_dir = kaldi_root / 'egs' / 'sms_wsj' / 's5'
    else:
        model_egs_dir = Path(_config['model_egs_dir']).expanduser().resolve()
    assert model_egs_dir.exists(), model_egs_dir

    dataset_names = _config['dataset_names']
    if not isinstance(dataset_names, (tuple, list)):
        dataset_names = [dataset_names]
    data_type = _config['data_type']
    if not isinstance(data_type, (tuple, list)):
        data_type = [data_type]

    kaldi_cmd = _config['kaldi_cmd']
    if base_dir != model_egs_dir and not (base_dir / 'steps').exists():
        create_kaldi_dir(base_dir, model_egs_dir, exist_ok=True)
        if kaldi_cmd == 'ssh.pl':
            CCS_NODEFILE = Path(os.environ['CCS_NODEFILE'])
            (base_dir / '.queue').mkdir()
            (base_dir / '.queue' / 'machines').write_text(
                CCS_NODEFILE.read_text())
        elif kaldi_cmd == 'run.pl':
            pass
        else:
            raise ValueError(kaldi_cmd)

    for d_type in data_type:
        for dset in dataset_names:
            dataset_dir = base_dir / 'data' / d_type / dset
            if audio_dir is not None:
                assert len(data_type) == 1, data_type
                create_dir(audio_dir,
                           base_dir=base_dir,
                           db=db,
                           dataset_names=dset)
            elif kaldi_data_dir is None:
                create_data_dir(base_dir,
                                db=db,
                                data_type=d_type,
                                dataset_names=dset,
                                ref_channels=_config['ref_channels'],
                                target_speaker=_config['target_speaker'])
            else:
                assert len(data_type) == 1, (
                    'When using a predefined kaldi_data_dir, not more than '
                    'one data_type should be defined. Better use the decode '
                    'command directly.')
                copytree(kaldi_data_dir / dset, dataset_dir, symlinks=True)
                run_process([f'utils/fix_data_dir.sh', f'{dataset_dir}'],
                            cwd=str(base_dir),
                            stdout=None,
                            stderr=None)

            decode(base_dir=base_dir,
                   model_egs_dir=model_egs_dir,
                   dataset_dir=dataset_dir,
                   model_dir=check_config_element(_config['model_dir']),
                   ivector_dir=check_config_element(_config['ivector_dir']),
                   extractor_dir=check_config_element(
                       _config['extractor_dir']),
                   data_type=d_type)
Example #20
def main(_run, batch_size, datasets, debug, experiment_dir, database_json,
         _log):
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset_name in datasets:
            dataset = prepare_dataset(db,
                                      dataset_name,
                                      batch_size,
                                      return_keys=None,
                                      prefetch=False,
                                      shuffle=False)

            for batch in dlp_mpi.split_managed(dataset,
                                               is_indexable=True,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(model.example_to_device(batch))

                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]

                input_metrics = pb_bss.evaluation.InputMetrics(
                    observation=batch['y'][0][None, :],
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )

                output_metrics = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=z,
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )
                entry['input'] = dict(mir_eval=input_metrics.mir_eval, )
                entry['output'] = dict(mir_eval={
                    k: v
                    for k, v in output_metrics.mir_eval.items()
                    if k != 'selection'
                }, )

                entry['improvement'] = pb.utils.nested.nested_op(
                    operator.sub,
                    entry['output'],
                    entry['input'],
                )
                entry['selection'] = output_metrics.mir_eval['selection']

                summary[dataset_name][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        _log.info(f'len(summary_list): {len(summary_list)}')
        summary = pb.utils.nested.nested_merge(*summary_list)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')
            assert len(values) == len(
                db.get_dataset(dataset)
            ), 'Number of results needs to match length of dataset!'
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute and save mean of metrics
        means = compute_means(summary)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f"Saving means to: {mean_json_path}")
        pb.io.dump_json(means, mean_json_path)
Example #21
def main(_run, datasets, debug, experiment_dir, export_audio, sample_rate,
         _log, database_json, oracle_num_spk, max_iterations):
    experiment_dir = Path(experiment_dir)

    if mpi.IS_MASTER:
        sacred.commands.print_config(_run)
        dump_config_and_makefile()

    model = get_model()
    db = JsonDatabase(database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            iterable = prepare_iterable(
                db,
                dataset,
                1,
                chunk_size=-1,
                prefetch=False,
                shuffle=False,
                iterator_slice=slice(mpi.RANK, 20 if debug else None,
                                     mpi.SIZE),
            )

            if export_audio:
                (experiment_dir / 'audio' / dataset).mkdir(parents=True,
                                                           exist_ok=True)

            for batch in tqdm(
                    iterable,
                    total=len(iterable),
                    disable=not mpi.IS_MASTER,
                    desc=dataset,
            ):
                example_id = batch['example_id'][0]
                summary[dataset][example_id] = entry = dict()
                oracle_speaker_count = \
                    entry['oracle_speaker_count'] = batch['s'][0].shape[0]

                try:
                    model_output = model.decode(
                        pt.data.example_to_device(batch),
                        max_iterations=max_iterations,
                        oracle_num_speakers=oracle_speaker_count
                        if oracle_num_spk else None)

                    # Bring to numpy float64 for evaluation metrics computation
                    s = batch['s'][0].astype(np.float64)
                    z = model_output['out'][0].cpu().numpy().astype(np.float64)

                    estimated_speaker_count = \
                        entry['estimated_speaker_count'] = z.shape[0]
                    entry['source_counting_accuracy'] = \
                        estimated_speaker_count == oracle_speaker_count

                    if oracle_speaker_count == estimated_speaker_count:
                        # These evaluations don't work if the number of
                        # speakers in s and z don't match
                        entry['mir_eval'] = pb_bss.evaluation.mir_eval_sources(
                            s, z, return_dict=True)

                        # Get the correct order for si_sdr and saving
                        z = z[entry['mir_eval']['selection']]

                        entry['si_sdr'] = pb_bss.evaluation.si_sdr(s, z)
                    else:
                        warnings.warn(
                            'The number of speakers is estimated incorrectly '
                            'for some examples! The calculated SDR values '
                            'might not be representative!')

                    if export_audio:
                        entry['audio_path'] = batch['audio_path']
                        entry['audio_path'].setdefault('estimated', [])

                        for k, audio in enumerate(z):
                            audio_path = (experiment_dir / 'audio' / dataset /
                                          f'{example_id}_{k}.wav')
                            pb.io.dump_audio(audio,
                                             audio_path,
                                             sample_rate=sample_rate)
                            entry['audio_path']['estimated'].append(audio_path)
                except Exception:
                    _log.error(f'Exception was raised in example with ID '
                               f'"{example_id}"')
                    raise

    summary_list = mpi.gather(summary, root=mpi.MASTER)

    if mpi.IS_MASTER:
        # Combine all summaries to one
        for partial_summary in summary_list:
            for dataset, values in partial_summary.items():
                summary[dataset].update(values)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')

        # Write summary to JSON
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute means for some metrics
        mean_keys = [
            'mir_eval.sdr', 'mir_eval.sar', 'mir_eval.sir', 'si_sdr',
            'source_counting_accuracy'
        ]
        means = {}
        for dataset, dataset_results in summary.items():
            means[dataset] = {}
            flattened = {
                k: pb.utils.nested.flatten(v)
                for k, v in dataset_results.items()
            }
            for mean_key in mean_keys:
                try:
                    means[dataset][mean_key] = np.mean(
                        np.array([v[mean_key] for v in flattened.values()]))
                except KeyError:
                    warnings.warn(f'Couldn\'t compute mean for {mean_key}.')
            means[dataset] = pb.utils.nested.deflatten(means[dataset])

        mean_json_path = experiment_dir / 'means.json'
        _log.info(f'Exporting means: {mean_json_path}')
        pb.io.dump_json(means, mean_json_path)

        _log.info('Resulting means:')

        pprint(means)
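
The mean computation above relies on paderbox's nested helpers: flatten joins nested keys with '.' (hence mean keys like 'mir_eval.sdr') and deflatten inverts this. A small sketch of the round trip:

import paderbox as pb

nested = {'mir_eval': {'sdr': 3.0, 'sir': 7.0}}
flat = pb.utils.nested.flatten(nested)
assert flat == {'mir_eval.sdr': 3.0, 'mir_eval.sir': 7.0}
assert pb.utils.nested.deflatten(flat) == nested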
Example #22
def evaluate(checkpoint_path, eval_dir, database_json):
    model = SimpleMaskEstimator(513)

    model.load_checkpoint(
        checkpoint_path=checkpoint_path,
        in_checkpoint_path='model',
        consider_mpi=True
    )
    model.eval()
    if dlp_mpi.IS_MASTER:
        print(f'Start to evaluate the checkpoint {checkpoint_path.resolve()} '
              f'and will write the evaluation result to'
              f' {eval_dir / "result.json"}')
    database = JsonDatabase(database_json)
    test_dataset = get_test_dataset(database)
    with torch.no_grad():
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for batch in dlp_mpi.split_managed(
                test_dataset, is_indexable=True,
                progress_bar=True,
                allow_single_worker=True
        ):
            model_output = model(pt.data.example_to_device(batch))

            example_id = batch['example_id']
            s = batch['speech_source'][0][None]

            speech_mask = model_output['speech_mask_prediction'].numpy()
            Y = batch['observation_stft']
            Z_mask = speech_mask[0] * Y[0]
            z_mask = pb.transform.istft(Z_mask)[None]

            speech_mask = np.median(speech_mask, axis=0).T
            noise_mask = model_output['noise_mask_prediction'].numpy()
            noise_mask = np.median(noise_mask, axis=0).T
            Y = rearrange(Y, 'c t f -> f c t')
            target_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, speech_mask,
            )
            noise_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, noise_mask,
            )
            beamformer = pb_bss.extraction.get_bf_vector(
                'mvdr_souden',
                target_psd_matrix=target_psd,
                noise_psd_matrix=noise_psd,
            )
            Z_bf = pb_bss.extraction.apply_beamforming_vector(beamformer, Y).T
            z_bf = pb.transform.istft(Z_bf)[None]

            y = batch['observation'][0][None]
            s = s[:, :z_bf.shape[1]]
            for key, signal in zip(summary.keys(), [z_mask, z_bf, y]):
                signal = signal[:, :s.shape[1]]
                entry = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=signal, speech_source=s,
                    sample_rate=16000
                ).as_dict()
                entry.pop('mir_eval_selection')
                summary[key][example_id] = entry

    summary_list = dlp_mpi.COMM.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'\n len(summary_list): {len(summary_list)}')
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for partial_summary in summary_list:
            for signal_type, metric in partial_summary.items():
                summary[signal_type].update(metric)
        for signal_type, values in summary.items():
            print(signal_type)
            for metric in next(iter(values.values())).keys():
                mean = np.mean([value[metric] for key, value in values.items()
                                if '_mean' not in key])
                values[metric + '_mean'] = mean
                print(f'{metric}: {mean}')

        result_json_path = eval_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
Example #23
def get_test_dataset(database: JsonDatabase):
    val_iterator = database.get_dataset('et05_simu')
    return val_iterator.map(prepare_data)
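
Usage sketch for the helper above; the json path is a placeholder and prepare_data is defined elsewhere in this recipe. The .map call is lazy, so prepare_data only runs while the dataset is iterated:

db = JsonDatabase('chime4.json')  # placeholder path
test_dataset = get_test_dataset(db)
for example in test_dataset:      # prepare_data is applied on the fly
    pass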
Example #24
def main(json_path: Path, rir_dir: Path, wsj_json_path: Path, num_speakers):
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    if json_path.exists():
        raise FileExistsError(json_path)
    rir_dir = Path(rir_dir).expanduser().resolve()
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")

    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name in setup.keys():
        source_dataset_name = setup[dataset_name]["source_dataset_name"]
        source_iterator = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_iterator)}')
        source_iterator = source_iterator.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False)
        print(f'length of source {dataset_name}: {len(source_iterator)} '
              '(after punctuation filter)')

        rir_iterator = rir_db.get_dataset(dataset_name)

        assert len(rir_iterator) % len(source_iterator) == 0, (
            f'To avoid a bias towards certain utterances, the length of '
            f'rir_iterator ({len(rir_iterator)}) should be an integer '
            f'multiple of the length of source_iterator '
            f'({len(source_iterator)}).')

        print(f'length of rir {dataset_name}: {len(rir_iterator)}')

        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.')

        info = soundfile.info(str(rir_dir / dataset_name / "0" / "h_0.wav"))
        sample_rate_rir = info.samplerate

        ex_wsj = source_iterator.random_choice(1)[0]
        info = soundfile.info(ex_wsj['audio_path']['observation'])
        sample_rate_wsj = info.samplerate
        assert sample_rate_rir == sample_rate_wsj, (sample_rate_rir,
                                                    sample_rate_wsj)

        rir_iterator = rir_iterator.sort(
            sort_fn=functools.partial(sorted, key=int))

        source_iterator = source_iterator.sort()
        assert len(rir_iterator) % len(source_iterator) == 0
        repeats = len(rir_iterator) // len(source_iterator)
        source_iterator = source_iterator.tile(repeats)

        speaker_ids = [example['speaker_id'] for example in source_iterator]

        rng = get_rng(dataset_name, 'example_compositions')

        example_compositions = None
        for _ in range(num_speakers):
            example_compositions = extend_example_composition_greedy(
                rng,
                speaker_ids,
                example_compositions=example_compositions,
            )

        ex_dict = dict()
        assert len(rir_iterator) == len(example_compositions)
        for rir_example, example_composition in zip(rir_iterator,
                                                    example_compositions):
            source_examples = source_iterator[example_composition]

            example_id = "_".join([
                rir_example['example_id'],
                *[ex["example_id"] for ex in source_examples],
            ])

            rng = get_rng(dataset_name, example_id)
            example = get_randomized_example(
                rir_example,
                source_examples,
                rng,
                dataset_name,
                rir_dir,
            )
            ex_dict[example_id] = example

        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written')
Example #25
class DataProvider(Configurable):
    json_path: str
    audio_reader: Callable
    train_set: dict
    validate_set: str = None
    cached_datasets: list = None
    min_audio_length: float = 1.
    train_segmenter: float = None
    test_segmenter: float = None
    train_transform: Callable = None
    test_transform: Callable = None
    train_fetcher: Callable = None
    test_fetcher: Callable = None
    label_key: str = 'events'
    discard_labelless_train_examples: bool = True
    storage_dir: str = None
    # augmentation
    min_class_examples_per_epoch: int = 0
    scale_sampling_fn: Callable = None
    mix_interval: float = 1.5
    mix_fn: Callable = None

    def __post_init__(self):
        assert self.json_path is not None
        self.db = JsonDatabase(json_path=self.json_path)

    def get_train_set(self, filter_example_ids=None):
        return self.get_dataset(self.train_set,
                                train=True,
                                filter_example_ids=filter_example_ids)

    def get_validate_set(self, filter_example_ids=None):
        return self.get_dataset(self.validate_set,
                                train=False,
                                filter_example_ids=filter_example_ids)

    def get_dataset(self,
                    dataset_names_or_raw_datasets,
                    train=False,
                    filter_example_ids=None):
        ds = self.prepare_audio(dataset_names_or_raw_datasets,
                                train=train,
                                filter_example_ids=filter_example_ids)
        ds = self.segment_transform_and_fetch(ds, train=train)
        return ds

    def prepare_audio(self,
                      dataset_names_or_raw_datasets,
                      train=False,
                      filter_example_ids=None):
        individual_audio_datasets = self._load_audio(
            dataset_names_or_raw_datasets,
            train=train,
            filter_example_ids=filter_example_ids)
        if not isinstance(individual_audio_datasets, list):
            assert isinstance(
                individual_audio_datasets,
                lazy_dataset.Dataset), type(individual_audio_datasets)
            individual_audio_datasets = [(individual_audio_datasets, 1)]
        combined_audio_dataset = self._tile_and_intersperse(
            individual_audio_datasets, shuffle=train)
        if train and self.min_class_examples_per_epoch > 0:
            assert self.label_key is not None
            raw_datasets = self.get_raw(
                dataset_names_or_raw_datasets,
                discard_labelless_examples=(
                    self.discard_labelless_train_examples),
                filter_example_ids=filter_example_ids,
            )
            label_counts, labels = self._count_labels(raw_datasets,
                                                      self.label_key)
            label_reps = self._compute_label_repetitions(
                label_counts, min_counts=self.min_class_examples_per_epoch)
            repetition_groups = self._build_repetition_groups(
                individual_audio_datasets, labels, label_reps)
            dataset = self._tile_and_intersperse(repetition_groups,
                                                 shuffle=train)
        else:
            dataset = combined_audio_dataset
        if train:
            # dataset = self.scale_and_mix(dataset, combined_audio_dataset)
            dataset = self.scale_and_mix(dataset, dataset)
        print('Total data set length:', len(dataset))
        return dataset

    def _load_audio(self,
                    dataset_names_or_raw_datasets,
                    train=False,
                    filter_example_ids=None,
                    idx=None):
        if isinstance(dataset_names_or_raw_datasets, (dict, list, tuple)):
            ds = []
            for i, name_or_ds in enumerate(dataset_names_or_raw_datasets):
                num_reps = (
                    dataset_names_or_raw_datasets[name_or_ds] if isinstance(
                        dataset_names_or_raw_datasets, dict) else
                    name_or_ds[1] if isinstance(name_or_ds,
                                                (list, tuple)) else 1)
                if num_reps == 0:
                    continue
                ds.append((self._load_audio(
                    name_or_ds[0] if isinstance(name_or_ds,
                                                (list, tuple)) else name_or_ds,
                    train=train,
                    filter_example_ids=filter_example_ids,
                    idx=i,
                ), num_reps))
            return ds
        ds = self.get_raw(
            dataset_names_or_raw_datasets,
            discard_labelless_examples=(train and
                                        self.discard_labelless_train_examples),
            filter_example_ids=filter_example_ids,
        ).map(self.audio_reader)
        cache = (self.cached_datasets is not None
                 and isinstance(dataset_names_or_raw_datasets, str)
                 and dataset_names_or_raw_datasets in self.cached_datasets)
        if cache:
            ds = ds.cache(lazy=False)

        if isinstance(dataset_names_or_raw_datasets, str):
            ds_name = " " + dataset_names_or_raw_datasets
        else:
            ds_name = ""
        if idx is not None:
            ds_name += f" [{idx}]"
        print(f'Single data set length{ds_name}:', len(ds))
        return ds

    def get_raw(
        self,
        dataset_names_or_raw_datasets,
        discard_labelless_examples=False,
        filter_example_ids=None,
    ):
        if isinstance(dataset_names_or_raw_datasets, (dict, list, tuple)):
            return list(
                filter(lambda x: x[1] > 0, [(
                    self.get_raw(
                        name_or_ds[0] if isinstance(name_or_ds,
                                                    (list,
                                                     tuple)) else name_or_ds,
                        discard_labelless_examples=discard_labelless_examples,
                        filter_example_ids=filter_example_ids,
                    ),
                    (dataset_names_or_raw_datasets[name_or_ds] if isinstance(
                        dataset_names_or_raw_datasets, dict) else
                     name_or_ds[1] if isinstance(name_or_ds,
                                                 (list, tuple)) else 1),
                ) for name_or_ds in dataset_names_or_raw_datasets]))
        elif isinstance(dataset_names_or_raw_datasets, str):
            ds = self.db.get_dataset(dataset_names_or_raw_datasets)
        else:
            assert isinstance(
                dataset_names_or_raw_datasets,
                lazy_dataset.Dataset), type(dataset_names_or_raw_datasets)
            ds = dataset_names_or_raw_datasets
        if discard_labelless_examples:
            ds = ds.filter(
                lambda ex: self.label_key in ex and ex[self.label_key],
                lazy=False)
        if filter_example_ids is not None:
            ds = ds.filter(
                lambda ex: ex['example_id'] not in filter_example_ids,
                lazy=False)
        return ds.filter(lambda ex: 'audio_length' in ex and ex['audio_length']
                         > self.min_audio_length,
                         lazy=False)

    @staticmethod
    def _tile_and_intersperse(datasets, shuffle=False):
        if shuffle:
            datasets = [(ds.shuffle(reshuffle=True), reps)
                        for ds, reps in datasets]
        return lazy_dataset.intersperse(
            *[ds.tile(reps) for ds, reps in datasets])

    def scale_and_mix(self, dataset, mixin_dataset=None):
        if mixin_dataset is None:
            mixin_dataset = dataset
        if self.scale_sampling_fn is not None:

            def scale(example):
                w = self.scale_sampling_fn()
                example['audio_data'] = example['audio_data'] * w
                return example

            dataset = dataset.map(scale)
            mixin_dataset = mixin_dataset.map(scale)

        if self.mix_interval is not None:
            # mixin_dataset = mixin_dataset.tile(
            #     math.ceil(len(dataset)/len(combined_audio_dataset)))
            assert self.mix_fn is not None
            dataset = MixtureDataset(dataset,
                                     mixin_dataset,
                                     mix_interval=self.mix_interval,
                                     mix_fn=self.mix_fn)
        return dataset

    def _count_labels(self,
                      raw_datasets,
                      label_key,
                      label_counts=None,
                      reps=1):
        if label_counts is None:
            label_counts = defaultdict(int)
        if isinstance(raw_datasets, list):
            labels = []
            for ds, ds_reps in raw_datasets:
                label_counts, cur_labels = self._count_labels(
                    ds,
                    label_key,
                    label_counts=label_counts,
                    reps=ds_reps * reps)
                labels.append(cur_labels)
            return label_counts, labels

        labels = []
        for example in raw_datasets:
            cur_labels = sorted(set(to_list(example[label_key])))
            labels.append(cur_labels)
            for label in cur_labels:
                label_counts[label] += reps
        return label_counts, labels

    @staticmethod
    def _compute_label_repetitions(label_counts, min_counts):
        max_count = max(label_counts.values())
        if isinstance(min_counts, float):
            assert 0. < min_counts < 1., min_counts
            min_counts = math.ceil(max_count * min_counts)
        assert isinstance(min_counts, int) and min_counts > 1, min_counts
        assert min_counts - 1 <= 0.9 * max_count, (min_counts, max_count)
        base_rep = 1 // (1 - (min_counts - 1) / max_count)
        min_counts *= base_rep
        label_repetitions = {
            label: math.ceil(min_counts / count)
            for label, count in label_counts.items()
        }
        return label_repetitions

    def _build_repetition_groups(self, dataset, labels, label_repetitions):
        assert len(dataset) == len(labels), (len(dataset), len(labels))
        if isinstance(dataset, list):
            return [(group_ds, ds_reps * group_reps)
                    for (ds, ds_reps), cur_labels in zip(dataset, labels)
                    for group_ds, group_reps in self._build_repetition_groups(
                        ds, cur_labels, label_repetitions)]
        idx_reps = [
            max([label_repetitions[label] for label in idx_labels])
            for idx_labels in labels
        ]
        rep_groups = {}
        for n_reps in set(idx_reps):
            rep_groups[n_reps] = np.argwhere(
                np.array(idx_reps) == n_reps).flatten().tolist()
        datasets = []
        for n_reps, indices in sorted(rep_groups.items(), key=lambda x: x[0]):
            datasets.append((dataset[sorted(indices)], n_reps))
        return datasets

    def segment_transform_and_fetch(
        self,
        dataset,
        segment=True,
        transform=True,
        fetch=True,
        train=False,
    ):
        segmenter = self.train_segmenter if train else self.test_segmenter
        segment = segment and segmenter is not None
        if segment:
            dataset = dataset.map(segmenter)
        if transform:
            transform = self.train_transform if train else self.test_transform
            assert transform is not None
            if segment:
                dataset = dataset.batch_map(transform)
            else:
                dataset = dataset.map(transform)
        if fetch:
            fetcher = self.train_fetcher if train else self.test_fetcher
            assert fetcher is not None
            dataset = fetcher(dataset, batched_input=segment)
        return dataset

    @classmethod
    def finalize_dogmatic_config(cls, config):
        config['audio_reader'] = {
            'factory': AudioReader,
            'source_sample_rate': None,
            'target_sample_rate': 16000,
            'average_channels': True,
            'normalization_domain': 'instance',
            'normalization_type': 'max',
            'alignment_keys': ['events'],
        }
        config['train_transform'] = {
            'factory': Transform,
            'stft': {
                'factory': STFT,
                'shift': 320,
                'window_length': 960,
                'size': 1024,
                'fading': 'half',
                'pad': True,
                'alignment_keys': ['events'],
            },
            'label_encoder': {
                'factory': MultiHotAlignmentEncoder,
                'label_key': 'events',
                'storage_dir': config['storage_dir'],
            },
            'anchor_sampling_fn': {
                'factory': Uniform,
                'low': 0.4,
                'high': 0.6,
            },
            'anchor_shift_sampling_fn': {
                'factory': Uniform,
                'low': -0.1,
                'high': 0.1,
            },
        }
        config['test_transform'] = {
            'factory': Transform,
            'stft': config['train_transform']['stft'],
            'label_encoder': config['train_transform']['label_encoder'],
        }
        config['train_fetcher'] = {
            'factory': DataFetcher,
            'prefetch_workers': 16,
            'batch_size': 16,
            'max_padding_rate': .05,
            'drop_incomplete': True,
            'global_shuffle': False,  # already shuffled in prepare_audio
        }
        config['train_fetcher']['bucket_expiration'] = (
            2000 * config['train_fetcher']['batch_size'])
        config['test_fetcher'] = {
            'factory': DataFetcher,
            'prefetch_workers': config['train_fetcher']['prefetch_workers'],
            'batch_size': 2 * config['train_fetcher']['batch_size'],
            'max_padding_rate': config['train_fetcher']['max_padding_rate'],
            'bucket_expiration': config['train_fetcher']['bucket_expiration'],
            'drop_incomplete': False,
            'global_shuffle': False,
        }
        config['scale_sampling_fn'] = {
            'factory': LogTruncatedNormal,
            'loc': 0.,
            'scale': 1.,
            'truncation': np.log(3.),
        }
        if config['mix_interval'] is not None:
            config['mix_fn'] = {
                'factory': SuperposeEvents,
                'min_overlap': 1.,
                'fade_length':
                config['train_transform']['stft']['window_length'],
                'label_key': 'events',
            }
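
A quick worked example of the label-repetition logic in _compute_label_repetitions above; this is a sketch with made-up label names and counts, not part of the original code. A float min_counts is first converted into an absolute count relative to the most frequent label, and each rarer label is then tiled until it reaches that count:

import math

label_counts = {'dog_bark': 1000, 'cat_meow': 100}  # hypothetical counts
min_counts = 0.5  # request at least 50% of the most frequent label's count
max_count = max(label_counts.values())              # 1000
min_counts = math.ceil(max_count * min_counts)      # 500
base_rep = 1 // (1 - (min_counts - 1) / max_count)  # 1.0 in this case
min_counts *= base_rep
reps = {label: math.ceil(min_counts / count)
        for label, count in label_counts.items()}
print(reps)  # {'dog_bark': 1, 'cat_meow': 5} -> rare label tiled 5x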
Ejemplo n.º 26
0
def main(_run, exp_dir, storage_dir, database_json, test_set, max_examples,
         device):
    if IS_MASTER:
        commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)
    audio_dir = storage_dir / 'audio'
    # Every MPI rank runs main(), so tolerate an already existing directory.
    audio_dir.mkdir(parents=True, exist_ok=True)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir, consider_mpi=True)
    model.to(device)
    model.eval()

    db = JsonDatabase(database_json)
    test_data = db.get_dataset(test_set)
    if max_examples is not None:
        test_data = test_data.shuffle(
            rng=np.random.RandomState(0))[:max_examples]
    test_data = prepare_dataset(test_data,
                                audio_reader=config['audio_reader'],
                                stft=config['stft'],
                                max_length=None,
                                batch_size=1,
                                shuffle=True)
    squared_err = list()
    with torch.no_grad():
        for example in split_managed(test_data,
                                     is_indexable=False,
                                     progress_bar=True,
                                     allow_single_worker=True):
            example = model.example_to_device(example, device)
            target = example['audio_data'].squeeze(1)
            x = model.feature_extraction(example['stft'], example['seq_len'])
            x = model.wavenet.infer(
                x.squeeze(1),
                chunk_length=80_000,
                chunk_overlap=16_000,
            )
            assert target.shape == x.shape, (target.shape, x.shape)
            # Accumulate (example_id, sum of squared errors, num_samples).
            se = ((x - target) ** 2).sum(1)
            squared_err.extend([
                (ex_id, err.cpu().detach().numpy(), x.shape[1])
                for ex_id, err in zip(example['example_id'], se)
            ])

    squared_err_list = COMM.gather(squared_err, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(squared_err_list): {len(squared_err_list)}')
        # Flatten the per-rank lists gathered from all MPI workers.
        squared_err = [item for part in squared_err_list for item in part]
        _, err, t = list(zip(*squared_err))
        print('rmse:', np.sqrt(np.sum(err) / np.sum(t)))
        rmse = sorted([(ex_id, np.sqrt(err / t))
                       for ex_id, err, t in squared_err],
                      key=lambda x: x[1])
        dump_json(rmse, storage_dir / 'rmse.json', indent=4, sort_keys=False)
        ex_ids_ordered = [x[0] for x in rmse]
        # Re-create the same shuffled subset evaluated above.
        test_data = db.get_dataset(test_set).shuffle(
            rng=np.random.RandomState(0))[:max_examples].filter(
                # Keep only the 10 best and 10 worst examples by RMSE.
                lambda ex: ex['example_id'] in (
                    ex_ids_ordered[:10] + ex_ids_ordered[-10:]),
                lazy=False)
        test_data = prepare_dataset(test_data,
                                    audio_reader=config['audio_reader'],
                                    stft=config['stft'],
                                    max_length=10.,
                                    batch_size=1,
                                    shuffle=True)
        with torch.no_grad():
            for example in test_data:
                example = model.example_to_device(example, device)
                x = model.feature_extraction(example['stft'],
                                             example['seq_len'])
                x = model.wavenet.infer(
                    x.squeeze(1),
                    chunk_length=80_000,
                    chunk_overlap=16_000,
                )
                for i, audio in enumerate(x.cpu().detach().numpy()):
                    wavfile.write(
                        str(audio_dir / f'{example["example_id"][i]}.wav'),
                        model.sample_rate, audio)
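
The RMSE printed above is pooled over samples rather than averaged per example, so long utterances carry proportionally more weight; a minimal sketch of the difference (numbers are illustrative):

import numpy as np

err = np.array([4.0, 9.0])  # per-example sums of squared errors
t = np.array([100, 400])    # per-example sample counts
pooled_rmse = np.sqrt(err.sum() / t.sum())  # single corpus-level estimate
per_example_rmse = np.sqrt(err / t)         # used above to rank examples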
Ejemplo n.º 27
0
def main(_run, datasets, debug, experiment_dir, dump_audio, sample_rate, _log,
         database_json):
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)
        dump_config_and_makefile()

    model = get_model()
    db = JsonDatabase(database_json)

    model.eval()
    results = defaultdict(dict)
    with torch.no_grad():
        for dataset in datasets:
            iterable = prepare_dataset(
                db,
                dataset,
                1,
                chunk_size=-1,
                prefetch=False,
                shuffle=False,
                dataset_slice=slice(dlp_mpi.RANK, 20 if debug else None,
                                    dlp_mpi.SIZE),
            )

            if dump_audio:
                (experiment_dir / 'audio' / dataset).mkdir(parents=True,
                                                           exist_ok=True)

            for batch in tqdm(iterable,
                              total=len(iterable),
                              disable=not dlp_mpi.IS_MASTER):
                example_id = batch['example_id'][0]
                results[dataset][example_id] = entry = dict()

                try:
                    model_output = model(model.example_to_device(batch))

                    # Bring to numpy float64 for evaluation metrics computation
                    observation = batch['y'][0].astype(np.float64)[None, :]
                    speech_prediction = (
                        model_output['out'][0].cpu().numpy().astype(
                            np.float64))
                    speech_source = batch['s'][0].astype(np.float64)

                    input_metrics = pb_bss.evaluation.InputMetrics(
                        observation=observation,
                        speech_source=speech_source,
                        sample_rate=sample_rate,
                        enable_si_sdr=True,
                    )

                    output_metrics = pb_bss.evaluation.OutputMetrics(
                        speech_prediction=speech_prediction,
                        speech_source=speech_source,
                        sample_rate=sample_rate,
                        enable_si_sdr=True,
                    )

                    # Select the metrics to compute
                    entry['input'] = dict(
                        mir_eval=input_metrics.mir_eval,
                        si_sdr=input_metrics.si_sdr,
                        # TODO: stoi fails with short speech segments (https://github.com/mpariente/pystoi/issues/21)
                        # stoi=input_metrics.stoi,
                        # TODO: pesq creates "Processing error" messages
                        # pesq=input_metrics.pesq,
                    )

                    # Remove selection from mir_eval dict to enable
                    # recursive calculation of improvement
                    entry['output'] = dict(
                        mir_eval={
                            k: v
                            for k, v in output_metrics.mir_eval.items()
                            if k != 'selection'
                        },
                        si_sdr=output_metrics.si_sdr,
                        # stoi=output_metrics.stoi,
                        # pesq=output_metrics.pesq,
                    )

                    entry['improvement'] = pb.utils.nested.nested_op(
                        operator.sub,
                        entry['output'],
                        entry['input'],
                    )
                    entry['selection'] = output_metrics.mir_eval['selection']

                    if dump_audio:
                        entry['audio_path'] = batch['audio_path']
                        entry['audio_path']['estimated'] = []
                        for k, audio in enumerate(speech_prediction):
                            audio_path = (experiment_dir / 'audio' / dataset /
                                          f'{example_id}_{k}.wav')
                            pb.io.dump_audio(audio,
                                             audio_path,
                                             sample_rate=sample_rate)
                            entry['audio_path']['estimated'].append(audio_path)
                except Exception:
                    _log.error(f'Exception was raised in example with ID '
                               f'"{example_id}"')
                    raise

    results = dlp_mpi.gather(results, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        # Combine all results to one. This function raises an exception if it
        # finds duplicate keys
        results = pb.utils.nested.nested_merge(*results)

        for dataset, values in results.items():
            _log.info(f'{dataset}: {len(values)}')

        # Write results to JSON
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(results, result_json_path)

        # Compute means for some metrics
        means = compute_means(results)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f'Exporting means: {mean_json_path}')
        pb.io.dump_json(means, mean_json_path)

        _log.info('Resulting means:')

        pprint(means)
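
The improvement entry above relies on pb.utils.nested.nested_op, which applies a binary operator leaf-wise to two nested structures of identical shape; a minimal sketch with made-up numbers, assuming paderbox is installed:

import operator
import numpy as np
import paderbox as pb

out = {'si_sdr': np.array([12.3]), 'mir_eval': {'sdr': np.array([13.0])}}
inp = {'si_sdr': np.array([2.1]), 'mir_eval': {'sdr': np.array([3.2])}}
improvement = pb.utils.nested.nested_op(operator.sub, out, inp)
# -> {'si_sdr': array([10.2]), 'mir_eval': {'sdr': array([9.8])}}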
Ejemplo n.º 28
0
def write_wavs(dst_dir, json_path, write_all=False, snr_range=(20, 30)):
    db = JsonDatabase(json_path)
    if write_all:
        if dlp_mpi.IS_MASTER:
            for data_type in KEY_MAPPER.values():
                (dst_dir / data_type).mkdir(exist_ok=False)
        map_fn = partial(scenario_map_fn,
                         snr_range=snr_range,
                         sync_speech_source=True,
                         add_speech_reverberation_early=True,
                         add_speech_reverberation_tail=True)
    else:
        if dlp_mpi.IS_MASTER:
            (dst_dir / 'observation').mkdir(exist_ok=False)
        map_fn = partial(scenario_map_fn,
                         snr_range=snr_range,
                         sync_speech_source=True,
                         add_speech_reverberation_early=False,
                         add_speech_reverberation_tail=False)
    for dataset in ['train_si284', 'cv_dev93', 'test_eval92']:
        if dlp_mpi.IS_MASTER:
            for data_type in KEY_MAPPER.values():
                (dst_dir / data_type / dataset).mkdir(exist_ok=False)
        ds = db.get_dataset(dataset).map(audio_read).map(map_fn)
        for example in dlp_mpi.split_managed(
                ds,
                is_indexable=True,
                allow_single_worker=True,
                progress_bar=True,
        ):
            audio_dict = example['audio_data']
            example_id = example['example_id']
            if not write_all:
                del audio_dict['speech_reverberation_early']
                del audio_dict['speech_reverberation_tail']
                del audio_dict['noise_image']

            def get_abs_max(a):
                if isinstance(a, np.ndarray):
                    # np.object was removed in NumPy 1.24; compare to object.
                    if a.dtype == object:
                        return np.max(list(map(get_abs_max, a)))
                    else:
                        return np.max(np.abs(a))
                elif isinstance(a, (tuple, list)):
                    return np.max(list(map(get_abs_max, a)))
                elif isinstance(a, dict):
                    return np.max(list(map(get_abs_max, a.values())))
                else:
                    raise TypeError(a)

            assert get_abs_max(audio_dict), (example_id, {
                k: get_abs_max(v)
                for k, v in audio_dict.items()
            })
            for key, value in audio_dict.items():
                if key not in KEY_MAPPER:
                    continue
                path = dst_dir / KEY_MAPPER[key] / dataset
                if key in ['observation', 'noise_image']:
                    value = value[None]
                for idx, signal in enumerate(value):
                    appendix = f'_{idx}' if len(value) > 1 else ''
                    filename = example_id + appendix + '.wav'
                    audio_path = str(path / filename)
                    with soundfile.SoundFile(audio_path,
                                             subtype='FLOAT',
                                             mode='w',
                                             samplerate=8000,
                                             channels=1 if signal.ndim == 1
                                             else signal.shape[0]) as f:
                        f.write(signal.T)

        dlp_mpi.barrier()

    if dlp_mpi.IS_MASTER:
        created_files = check_files(dst_dir)
        print(f"Written {len(created_files)} wav files.")
        if write_all:
            # TODO Less, if you do a test run.
            num_speakers = 2  # todo infer num_speakers from json
            # 2 files for: early, tail, speech_source
            # 1 file for: observation, noise
            expect = (3 * num_speakers + 2) * 35875
            assert len(created_files) == expect, (len(created_files), expect)
        else:
            assert len(created_files) == 35875, len(created_files)
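
The SoundFile writes above store 32-bit float wavs with channels last on disk, which is why the (channels, samples) arrays are transposed; a minimal standalone sketch of the same pattern (file name and shapes are made up):

import numpy as np
import soundfile

signal = np.zeros((6, 8000), dtype=np.float32)  # 6 channels, 1 s at 8 kHz
with soundfile.SoundFile('observation.wav', mode='w', samplerate=8000,
                         channels=signal.shape[0], subtype='FLOAT') as f:
    f.write(signal.T)  # soundfile expects (frames, channels)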
Ejemplo n.º 29
0
    def __post_init__(self):
        assert self.json_path is not None
        self.db = JsonDatabase(json_path=self.json_path)


def main(
    json_path: Path,
    rir_dir: Path,
    wsj_json_path: Path,
    num_speakers: int,
    debug: bool,
):
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    rir_dir = Path(rir_dir).expanduser().resolve()
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    # ToDo: What was the motivation for defining this "setup"?
    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")

    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name in setup.keys():
        source_dataset_name = setup[dataset_name]["source_dataset_name"]
        source_dataset = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_dataset)}')
        source_dataset = source_dataset.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False)
        print(f'length of source {dataset_name}: {len(source_dataset)} '
              '(after punctuation filter)')

        def add_rir_path(rir_ex):
            assert 'audio_path' not in rir_ex, rir_ex
            example_id = rir_ex['example_id']
            rir_ex['audio_path'] = {
                'rir': [
                    str(rir_dir / dataset_name / example_id / f"h_{k}.wav")
                    for k in range(num_speakers)
                ]
            }
            return rir_ex

        rir_dataset = rir_db.get_dataset(dataset_name).map(add_rir_path)

        assert len(rir_dataset) % len(source_dataset) == 0, (
            f'To avoid a bias towards certain utterances, the length of '
            f'rir_dataset ({len(rir_dataset)}) should be an integer '
            f'multiple of the length of source_dataset '
            f'({len(source_dataset)}).')

        print(f'length of rir {dataset_name}: {len(rir_dataset)}')

        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.')

        info = soundfile.info(str(probe_path / "h_0.wav"))
        sample_rate_rir = info.samplerate

        ex_wsj = source_dataset.random_choice(1)[0]
        info = soundfile.info(ex_wsj['audio_path']['observation'])
        sample_rate_wsj = info.samplerate
        assert sample_rate_rir == sample_rate_wsj, (sample_rate_rir,
                                                    sample_rate_wsj)

        if debug:
            rir_dataset = rir_dataset[:DEBUG_EXAMPLE_LIMIT]
            # Use step_size to avoid that only one speaker is in
            # source_iterator.
            step_size = len(source_dataset) // DEBUG_EXAMPLE_LIMIT
            source_dataset = source_dataset[::step_size]

        ex_dict = combine_rirs_and_sources(
            rir_dataset=rir_dataset,
            source_dataset=source_dataset,
            num_speakers=num_speakers,
            dataset_name=dataset_name,
        )

        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written.')
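
The written JSON follows the layout lazy_dataset expects, a top-level 'datasets' mapping from dataset name to example dicts, so it can be consumed like the other databases in these examples; a short sketch (the file name below is hypothetical):

from lazy_dataset.database import JsonDatabase

db = JsonDatabase('sms_wsj.json')  # hypothetical output path of main() above
ds = db.get_dataset('cv_dev93')
example = ds.random_choice(1)[0]
print(len(ds), example['audio_path']['rir'])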