Example #1
def main(_run, _log, trainer, database_json, training_sets, validation_sets,
         audio_reader, stft, max_length_in_sec, batch_size, resume):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(_run.config,
                         _log,
                         config_filename=str(storage_dir / 'config.json'))

    db = JsonDatabase(database_json)
    training_data = db.get_dataset(training_sets)
    validation_data = db.get_dataset(validation_sets)
    training_data = prepare_dataset(training_data,
                                    audio_reader=audio_reader,
                                    stft=stft,
                                    max_length_in_sec=max_length_in_sec,
                                    batch_size=batch_size,
                                    shuffle=True)
    validation_data = prepare_dataset(validation_data,
                                      audio_reader=audio_reader,
                                      stft=stft,
                                      max_length_in_sec=max_length_in_sec,
                                      batch_size=batch_size,
                                      shuffle=False)

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data)
    trainer.train(training_data, resume=resume)
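The `prepare_dataset` helper used above is defined elsewhere in the package. A minimal sketch of what such a helper could look like on top of the lazy_dataset API, assuming `audio_reader` and `stft` are callables mapping one example dict to another (filtering by `max_length_in_sec` is omitted for brevity):

def prepare_dataset(dataset, audio_reader, stft, max_length_in_sec,
                    batch_size, shuffle=False):
    # Sketch only: load audio and compute STFT features per example.
    dataset = dataset.map(audio_reader).map(stft)
    if shuffle:
        # Draw a new example order at the start of every epoch.
        dataset = dataset.shuffle(reshuffle=True)
    # Group consecutive examples into batches of size batch_size.
    return dataset.batch(batch_size)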
Example #2
def get_datasets(storage_dir,
                 database_json,
                 dataset,
                 batch_size=16,
                 return_indexable=False):
    db = JsonDatabase(database_json)
    ds = db.get_dataset(dataset)

    def prepare_example(example):
        example['audio_path'] = example['audio_path']['observation']
        example['speaker_id'] = example['speaker_id'].split('-')[0]
        return example

    ds = ds.map(prepare_example)

    speaker_encoder = LabelEncoder(label_key='speaker_id',
                                   storage_dir=storage_dir,
                                   to_array=True)
    speaker_encoder.initialize_labels(dataset=ds, verbose=True)
    ds = ds.map(speaker_encoder)

    # LibriSpeech (the default database) does not share speakers across
    # datasets, i.e., datasets such as clean_100 and dev_clean have disjoint
    # speaker sets. To guarantee the same set of speakers during training,
    # validation, and evaluation, we therefore split the train set, e.g.,
    # clean_100 or clean_360.
    train_set, validate_set, test_set = train_test_split(ds)

    training_data = prepare_dataset(train_set, batch_size, training=True)
    validation_data = prepare_dataset(validate_set, batch_size, training=False)
    test_data = prepare_dataset(test_set,
                                batch_size,
                                training=False,
                                return_indexable=return_indexable)
    return training_data, validation_data, test_data
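`train_test_split` is likewise project-specific. Since lazy_dataset datasets support slicing, a deterministic split can be sketched as follows (the 80/10/10 ratio is an assumption for illustration, not the original split):

def train_test_split(ds):
    # Hypothetical 80/10/10 split into train/validation/test subsets.
    n = len(ds)
    n_train = int(0.8 * n)
    n_validate = int(0.1 * n)
    return (
        ds[:n_train],
        ds[n_train:n_train + n_validate],
        ds[n_train + n_validate:],
    )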
Example #3
    def __init__(
        self,
        json_path,
        target,
        dset,
        sample_rate=8000,
        single_channel=True,
        segment=4.0,
        nondefault_nsrc=None,
        normalize_audio=False,
    ):
        try:
            import sms_wsj  # noqa
        except ModuleNotFoundError:
            import warnings

            warnings.warn(
                "Some of the functionality relies on the sms_wsj package "
                "downloadable from https://github.com/fgnt/sms_wsj. "
                "The user is encouraged to install the package.")
        super().__init__()
        if target not in SMS_TARGETS:
            raise ValueError("Unexpected target {}, expected one of "
                             "{}".format(target, list(SMS_TARGETS.keys())))

        # Task setting
        self.json_path = json_path
        self.target = target
        self.target_dict = SMS_TARGETS[target]
        self.single_channel = single_channel
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio
        self.seg_len = None if segment is None else int(segment * sample_rate)
        if not nondefault_nsrc:
            self.n_src = self.target_dict["default_nsrc"]
        else:
            assert nondefault_nsrc >= self.target_dict["default_nsrc"]
            self.n_src = nondefault_nsrc
        self.like_test = self.seg_len is None
        self.dset = dset
        self.EPS = 1e-8

        # Load json files

        from lazy_dataset.database import JsonDatabase

        db = JsonDatabase(json_path)
        dataset = db.get_dataset(dset)
        # Filter out short utterances only when segment is specified
        if not self.like_test:

            def filter_short_examples(example):
                num_samples = example["num_samples"]["observation"]
                # Keep only utterances that are at least seg_len samples long.
                return num_samples >= self.seg_len

            dataset = dataset.filter(filter_short_examples, lazy=False)
        self.dataset = dataset
Example #4
# The bare field annotations and __post_init__ below imply a dataclass.
@dataclass
class DataProvider(Configurable):
    json_path: str
    audio_reader: Callable
    train_set: dict
    validate_set: str = None
    cached_datasets: list = None
    min_audio_length: float = 1.
    train_segmenter: float = None
    test_segmenter: float = None
    train_transform: Callable = None
    test_transform: Callable = None
    train_fetcher: Callable = None
    test_fetcher: Callable = None
    label_key: str = 'events'
    discard_labelless_train_examples: bool = True
    storage_dir: str = None
    # augmentation
    min_class_examples_per_epoch: int = 0
    scale_sampling_fn: Callable = None
    mix_interval: float = 1.5
    mix_fn: Callable = None

    def __post_init__(self):
        assert self.json_path is not None
        self.db = JsonDatabase(json_path=self.json_path)

    def get_train_set(self, filter_example_ids=None):
        return self.get_dataset(self.train_set,
                                train=True,
                                filter_example_ids=filter_example_ids)

    def get_validate_set(self, filter_example_ids=None):
        return self.get_dataset(self.validate_set,
                                train=False,
                                filter_example_ids=filter_example_ids)

    def get_dataset(self,
                    dataset_names_or_raw_datasets,
                    train=False,
                    filter_example_ids=None):
        ds = self.prepare_audio(dataset_names_or_raw_datasets,
                                train=train,
                                filter_example_ids=filter_example_ids)
        ds = self.segment_transform_and_fetch(ds, train=train)
        return ds

    def prepare_audio(self,
                      dataset_names_or_raw_datasets,
                      train=False,
                      filter_example_ids=None):
        individual_audio_datasets = self._load_audio(
            dataset_names_or_raw_datasets,
            train=train,
            filter_example_ids=filter_example_ids)
        if not isinstance(individual_audio_datasets, list):
            assert isinstance(
                individual_audio_datasets,
                lazy_dataset.Dataset), type(individual_audio_datasets)
            individual_audio_datasets = [(individual_audio_datasets, 1)]
        combined_audio_dataset = self._tile_and_intersperse(
            individual_audio_datasets, shuffle=train)
        if train and self.min_class_examples_per_epoch > 0:
            assert self.label_key is not None
            raw_datasets = self.get_raw(
                dataset_names_or_raw_datasets,
                discard_labelless_examples=(
                    self.discard_labelless_train_examples),
                filter_example_ids=filter_example_ids,
            )
            label_counts, labels = self._count_labels(raw_datasets,
                                                      self.label_key)
            label_reps = self._compute_label_repetitions(
                label_counts, min_counts=self.min_class_examples_per_epoch)
            repetition_groups = self._build_repetition_groups(
                individual_audio_datasets, labels, label_reps)
            dataset = self._tile_and_intersperse(repetition_groups,
                                                 shuffle=train)
        else:
            dataset = combined_audio_dataset
        if train:
            # dataset = self.scale_and_mix(dataset, combined_audio_dataset)
            dataset = self.scale_and_mix(dataset, dataset)
        print('Total data set length:', len(dataset))
        return dataset

    def _load_audio(self,
                    dataset_names_or_raw_datasets,
                    train=False,
                    filter_example_ids=None,
                    idx=None):
        if isinstance(dataset_names_or_raw_datasets, (dict, list, tuple)):
            ds = []
            for i, name_or_ds in enumerate(dataset_names_or_raw_datasets):
                if isinstance(dataset_names_or_raw_datasets, dict):
                    num_reps = dataset_names_or_raw_datasets[name_or_ds]
                elif isinstance(name_or_ds, (list, tuple)):
                    num_reps = name_or_ds[1]
                else:
                    num_reps = 1
                if num_reps == 0:
                    continue
                ds.append((self._load_audio(
                    name_or_ds[0] if isinstance(name_or_ds,
                                                (list, tuple)) else name_or_ds,
                    train=train,
                    filter_example_ids=filter_example_ids,
                    idx=i,
                ), num_reps))
            return ds
        ds = self.get_raw(
            dataset_names_or_raw_datasets,
            discard_labelless_examples=(train and
                                        self.discard_labelless_train_examples),
            filter_example_ids=filter_example_ids,
        ).map(self.audio_reader)
        cache = (self.cached_datasets is not None
                 and isinstance(dataset_names_or_raw_datasets, str)
                 and dataset_names_or_raw_datasets in self.cached_datasets)
        if cache:
            ds = ds.cache(lazy=False)

        if isinstance(dataset_names_or_raw_datasets, str):
            ds_name = " " + dataset_names_or_raw_datasets
        else:
            ds_name = ""
        if idx is not None:
            ds_name += f" [{idx}]"
        print(f'Single data set length{ds_name}:', len(ds))
        return ds

    def get_raw(
        self,
        dataset_names_or_raw_datasets,
        discard_labelless_examples=False,
        filter_example_ids=None,
    ):
        if isinstance(dataset_names_or_raw_datasets, (dict, list, tuple)):
            pairs = []
            for name_or_ds in dataset_names_or_raw_datasets:
                if isinstance(dataset_names_or_raw_datasets, dict):
                    num_reps = dataset_names_or_raw_datasets[name_or_ds]
                elif isinstance(name_or_ds, (list, tuple)):
                    num_reps = name_or_ds[1]
                else:
                    num_reps = 1
                if num_reps <= 0:
                    continue
                pairs.append((
                    self.get_raw(
                        name_or_ds[0]
                        if isinstance(name_or_ds, (list, tuple))
                        else name_or_ds,
                        discard_labelless_examples=discard_labelless_examples,
                        filter_example_ids=filter_example_ids,
                    ),
                    num_reps,
                ))
            return pairs
        elif isinstance(dataset_names_or_raw_datasets, str):
            ds = self.db.get_dataset(dataset_names_or_raw_datasets)
        else:
            assert isinstance(
                dataset_names_or_raw_datasets,
                lazy_dataset.Dataset), type(dataset_names_or_raw_datasets)
            ds = dataset_names_or_raw_datasets
        if discard_labelless_examples:
            ds = ds.filter(
                lambda ex: self.label_key in ex and ex[self.label_key],
                lazy=False)
        if filter_example_ids is not None:
            ds = ds.filter(
                lambda ex: ex['example_id'] not in filter_example_ids,
                lazy=False)
        return ds.filter(lambda ex: 'audio_length' in ex and ex['audio_length']
                         > self.min_audio_length,
                         lazy=False)

    @staticmethod
    def _tile_and_intersperse(datasets, shuffle=False):
        if shuffle:
            datasets = [(ds.shuffle(reshuffle=True), reps)
                        for ds, reps in datasets]
        return lazy_dataset.intersperse(
            *[ds.tile(reps) for ds, reps in datasets])

    def scale_and_mix(self, dataset, mixin_dataset=None):
        if mixin_dataset is None:
            mixin_dataset = dataset
        if self.scale_sampling_fn is not None:

            def scale(example):
                w = self.scale_sampling_fn()
                example['audio_data'] = example['audio_data'] * w
                return example

            dataset = dataset.map(scale)
            mixin_dataset = mixin_dataset.map(scale)

        if self.mix_interval is not None:
            # mixin_dataset = mixin_dataset.tile(
            #     math.ceil(len(dataset)/len(combined_audio_dataset)))
            assert self.mix_fn is not None
            dataset = MixtureDataset(dataset,
                                     mixin_dataset,
                                     mix_interval=self.mix_interval,
                                     mix_fn=self.mix_fn)
        return dataset

    def _count_labels(self,
                      raw_datasets,
                      label_key,
                      label_counts=None,
                      reps=1):
        if label_counts is None:
            label_counts = defaultdict(int)
        if isinstance(raw_datasets, list):
            labels = []
            for ds, ds_reps in raw_datasets:
                label_counts, cur_labels = self._count_labels(
                    ds,
                    label_key,
                    label_counts=label_counts,
                    reps=ds_reps * reps)
                labels.append(cur_labels)
            return label_counts, labels

        labels = []
        for example in raw_datasets:
            cur_labels = sorted(set(to_list(example[label_key])))
            labels.append(cur_labels)
            for label in cur_labels:
                label_counts[label] += reps
        # print(label_counts)
        return label_counts, labels

    @staticmethod
    def _compute_label_repetitions(label_counts, min_counts):
        max_count = max(label_counts.values())
        if isinstance(min_counts, float):
            assert 0. < min_counts < 1., min_counts
            min_counts = math.ceil(max_count * min_counts)
        assert isinstance(min_counts, int) and min_counts > 1, min_counts
        assert min_counts - 1 <= 0.9 * max_count, (min_counts, max_count)
        base_rep = 1 // (1 - (min_counts - 1) / max_count)
        min_counts *= base_rep
        label_repetitions = {
            label: math.ceil(min_counts / count)
            for label, count in label_counts.items()
        }
        return label_repetitions

    def _build_repetition_groups(self, dataset, labels, label_repetitions):
        assert len(dataset) == len(labels), (len(dataset), len(labels))
        if isinstance(dataset, list):
            return [(group_ds, ds_reps * group_reps)
                    for (ds, ds_reps), cur_labels in zip(dataset, labels)
                    for group_ds, group_reps in self._build_repetition_groups(
                        ds, cur_labels, label_repetitions)]
        idx_reps = [
            max([label_repetitions[label] for label in idx_labels])
            for idx_labels in labels
        ]
        rep_groups = {}
        for n_reps in set(idx_reps):
            rep_groups[n_reps] = np.argwhere(
                np.array(idx_reps) == n_reps).flatten().tolist()
        datasets = []
        for n_reps, indices in sorted(rep_groups.items(), key=lambda x: x[0]):
            datasets.append((dataset[sorted(indices)], n_reps))
        # ds = lazy_dataset.intersperse(*datasets)
        return datasets

    def segment_transform_and_fetch(
        self,
        dataset,
        segment=True,
        transform=True,
        fetch=True,
        train=False,
    ):
        segmenter = self.train_segmenter if train else self.test_segmenter
        segment = segment and segmenter is not None
        if segment:
            dataset = dataset.map(segmenter)
        if transform:
            transform = self.train_transform if train else self.test_transform
            assert transform is not None
            if segment:
                dataset = dataset.batch_map(transform)
            else:
                dataset = dataset.map(transform)
        if fetch:
            fetcher = self.train_fetcher if train else self.test_fetcher
            assert fetcher is not None
            dataset = fetcher(dataset, batched_input=segment)
        return dataset

    @classmethod
    def finalize_dogmatic_config(cls, config):
        config['audio_reader'] = {
            'factory': AudioReader,
            'source_sample_rate': None,
            'target_sample_rate': 16000,
            'average_channels': True,
            'normalization_domain': 'instance',
            'normalization_type': 'max',
            'alignment_keys': ['events'],
        }
        config['train_transform'] = {
            'factory': Transform,
            'stft': {
                'factory': STFT,
                'shift': 320,
                'window_length': 960,
                'size': 1024,
                'fading': 'half',
                'pad': True,
                'alignment_keys': ['events'],
            },
            'label_encoder': {
                'factory': MultiHotAlignmentEncoder,
                'label_key': 'events',
                'storage_dir': config['storage_dir'],
            },
            'anchor_sampling_fn': {
                'factory': Uniform,
                'low': 0.4,
                'high': 0.6,
            },
            'anchor_shift_sampling_fn': {
                'factory': Uniform,
                'low': -0.1,
                'high': 0.1,
            },
        }
        config['test_transform'] = {
            'factory': Transform,
            'stft': config['train_transform']['stft'],
            'label_encoder': config['train_transform']['label_encoder'],
        }
        config['train_fetcher'] = {
            'factory': DataFetcher,
            'prefetch_workers': 16,
            'batch_size': 16,
            'max_padding_rate': .05,
            'drop_incomplete': True,
            'global_shuffle': False,  # already shuffled in prepare_audio
        }
        config['train_fetcher']['bucket_expiration'] = (
            2000 * config['train_fetcher']['batch_size'])
        config['test_fetcher'] = {
            'factory': DataFetcher,
            'prefetch_workers': config['train_fetcher']['prefetch_workers'],
            'batch_size': 2 * config['train_fetcher']['batch_size'],
            'max_padding_rate': config['train_fetcher']['max_padding_rate'],
            'bucket_expiration': config['train_fetcher']['bucket_expiration'],
            'drop_incomplete': False,
            'global_shuffle': False,
        }
        config['scale_sampling_fn'] = {
            'factory': LogTruncatedNormal,
            'loc': 0.,
            'scale': 1.,
            'truncation': np.log(3.),
        }
        if config['mix_interval'] is not None:
            config['mix_fn'] = {
                'factory': SuperposeEvents,
                'min_overlap': 1.,
                'fade_length':
                config['train_transform']['stft']['window_length'],
                'label_key': 'events',
            }
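As a sanity check on the oversampling logic in `_compute_label_repetitions`: with an invented majority class of 100 examples and `min_counts=30`, `base_rep = 1 // (1 - 29/100) = 1.0`, so a minority class with 10 examples is tiled `ceil(30 / 10) = 3` times per epoch while the majority class is kept as-is (label names and counts below are purely illustrative):

label_counts = {'speech': 100, 'siren': 10}
reps = DataProvider._compute_label_repetitions(label_counts, min_counts=30)
# speech: ceil(30 / 100) = 1 repetition, siren: ceil(30 / 10) = 3 repetitions
assert reps == {'speech': 1, 'siren': 3}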
Example #5
def _create_data_dir(
        get_wer_command_fn, kaldi_dir, db=None, json_path=None,
        dataset_names=None, data_type='wsj_8k', target_speaker=0,
        ref_channels=0,
):
    """

    Args:
        get_wer_command_fn:
        kaldi_dir:
        db:
        json_path:
        dataset_names:
        data_type:
        target_speaker:
        ref_channels:

    Returns:

    """

    assert not (db is None and json_path is None), (db, json_path)
    if db is None:
        db = JsonDatabase(json_path)

    kaldi_dir = Path(kaldi_dir).expanduser().resolve()

    data_dir = kaldi_dir / 'data' / data_type
    data_dir.mkdir(exist_ok=True, parents=True)

    if not isinstance(ref_channels, (list, tuple)):
        ref_channels = [ref_channels]
    example_id_to_wav = dict()
    example_id_to_speaker = dict()
    example_id_to_trans = dict()
    example_id_to_duration = dict()
    speaker_to_gender = defaultdict(lambda: defaultdict(list))
    dataset_to_example_id = defaultdict(list)

    if dataset_names is None:
        dataset_names = ('train_si284', 'cv_dev93', 'test_eval92')
    elif isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    if not isinstance(target_speaker, (list, tuple)):
        target_speaker = [target_speaker]
    assert not any([
        (data_dir / dataset_name).exists() for dataset_name in dataset_names
    ]), (
        'One of the following directories already exists: '
        f'{[data_dir / ds_name for ds_name in dataset_names]}\n'
        'Delete them if you want to restart this stage'
    )

    print(
        'Create data dir for '
        f'{", ".join([f"{data_type}/{ds_name}" for ds_name in dataset_names])} '
        'data'
    )

    dataset = db.get_dataset(dataset_names)
    for example in dataset:
        for ref_ch in ref_channels:
            org_example_id = example['example_id']
            dataset_name = example['dataset']
            for t_spk in target_speaker:
                speaker_id = example['speaker_id'][t_spk]
                example_id = speaker_id + '_' + org_example_id
                example_id += f'_c{ref_ch}' if len(ref_channels) > 1 else ''
                example_id_to_wav[example_id] = get_wer_command_fn(
                    example, ref_ch=ref_ch, spk=t_spk)
                try:
                    transcription = example['kaldi_transcription'][t_spk]
                except KeyError:
                    transcription = example['transcription'][t_spk]
                example_id_to_trans[example_id] = transcription

                example_id_to_speaker[example_id] = speaker_id
                gender = example['gender'][t_spk]
                speaker_to_gender[dataset_name][speaker_id] = gender
                if isinstance(example['num_samples'], dict):
                    num_samples = example['num_samples']['observation']
                else:
                    num_samples = example['num_samples']
                example_id_to_duration[
                    example_id] = f"{num_samples / SAMPLE_RATE:.2f}"
                dataset_to_example_id[dataset_name].append(example_id)

    assert len(example_id_to_speaker) > 0, dataset
    for dataset_name in dataset_names:
        path = data_dir / dataset_name
        path.mkdir(exist_ok=False, parents=False)
        for name, dictionary in (
                ("utt2spk", example_id_to_speaker),
                ("text", example_id_to_trans),
                ("utt2dur", example_id_to_duration),
                ("wav.scp", example_id_to_wav)
        ):
            dictionary = {key: value for key, value in dictionary.items()
                          if key in dataset_to_example_id[dataset_name]}

            assert len(dictionary) > 0, (dataset_name, name)
            if name == 'utt2dur':
                dump_keyed_lines(dictionary, path / 'reco2dur')
            dump_keyed_lines(dictionary, path / name)
        dictionary = speaker_to_gender[dataset_name]
        assert len(dictionary) > 0, (dataset_name, 'spk2gender')
        dump_keyed_lines(dictionary, path / 'spk2gender')
        run_process(
            ['utils/fix_data_dir.sh', f'{path}'],
            cwd=str(kaldi_dir), stdout=None, stderr=None,
        )
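`dump_keyed_lines` is not shown above. Kaldi expects each of these files to contain one `<key> <value>` pair per line, so a minimal sketch of such a writer (an assumption, not the original implementation) is:

def dump_keyed_lines(dictionary, path):
    # Sketch: write "<key> <value>" per line, sorted by key, as Kaldi
    # expects for files like wav.scp, utt2spk, text and utt2dur.
    with open(path, 'w') as f:
        for key in sorted(dictionary):
            f.write(f'{key} {dictionary[key]}\n')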
Example #6
def get_test_dataset(database: JsonDatabase):
    val_iterator = database.get_dataset('et05_simu')
    return val_iterator.map(prepare_data)
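A usage sketch for this helper (the json path is hypothetical, and `prepare_data` is assumed to be defined alongside it):

from lazy_dataset.database import JsonDatabase

database = JsonDatabase('/path/to/chime4.json')  # hypothetical path
test_dataset = get_test_dataset(database)
for example in test_dataset:
    pass  # run the model on the prepared example here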
Example #7
def main(_run, batch_size, datasets, debug, experiment_dir, database_json,
         _log):
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset_name in datasets:
            dataset = prepare_dataset(db,
                                      dataset_name,
                                      batch_size,
                                      return_keys=None,
                                      prefetch=False,
                                      shuffle=False)

            for batch in dlp_mpi.split_managed(dataset,
                                               is_indexable=True,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(model.example_to_device(batch))

                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]

                input_metrics = pb_bss.evaluation.InputMetrics(
                    observation=batch['y'][0][None, :],
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )

                output_metrics = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=z,
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )
                entry['input'] = dict(mir_eval=input_metrics.mir_eval)
                entry['output'] = dict(mir_eval={
                    k: v
                    for k, v in output_metrics.mir_eval.items()
                    if k != 'selection'
                })

                entry['improvement'] = pb.utils.nested.nested_op(
                    operator.sub,
                    entry['output'],
                    entry['input'],
                )
                entry['selection'] = output_metrics.mir_eval['selection']

                summary[dataset_name][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        _log.info(f'len(summary_list): {len(summary_list)}')
        summary = pb.utils.nested.nested_merge(*summary_list)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')
            assert len(values) == len(
                db.get_dataset(dataset)
            ), 'Number of results needs to match length of dataset!'
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute and save mean of metrics
        means = compute_means(summary)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f"Saving means to: {mean_json_path}")
        pb.io.dump_json(means, mean_json_path)
Example #8
def main(_run, exp_dir, storage_dir, database_json, test_set, max_examples,
         device):
    if IS_MASTER:
        commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)
    audio_dir = storage_dir / 'audio'
    audio_dir.mkdir(parents=True)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir, consider_mpi=True)
    model.to(device)
    model.eval()

    db = JsonDatabase(database_json)
    test_data = db.get_dataset(test_set)
    if max_examples is not None:
        test_data = test_data.shuffle(
            rng=np.random.RandomState(0))[:max_examples]
    test_data = prepare_dataset(test_data,
                                audio_reader=config['audio_reader'],
                                stft=config['stft'],
                                max_length=None,
                                batch_size=1,
                                shuffle=True)
    squared_err = list()
    with torch.no_grad():
        for example in split_managed(test_data,
                                     is_indexable=False,
                                     progress_bar=True,
                                     allow_single_worker=True):
            example = model.example_to_device(example, device)
            target = example['audio_data'].squeeze(1)
            x = model.feature_extraction(example['stft'], example['seq_len'])
            x = model.wavenet.infer(
                x.squeeze(1),
                chunk_length=80_000,
                chunk_overlap=16_000,
            )
            assert target.shape == x.shape, (target.shape, x.shape)
            se_per_example = ((x - target) ** 2).sum(1)
            squared_err.extend([
                (ex_id, mse.cpu().detach().numpy(), x.shape[1])
                for ex_id, mse in zip(example['example_id'], se_per_example)
            ])

    squared_err_list = COMM.gather(squared_err, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(squared_err_list): {len(squared_err_list)}')
        squared_err = []
        for i in range(len(squared_err_list)):
            squared_err.extend(squared_err_list[i])
        _, err, t = list(zip(*squared_err))
        print('rmse:', np.sqrt(np.sum(err) / np.sum(t)))
        rmse = sorted([(ex_id, np.sqrt(err / t))
                       for ex_id, err, t in squared_err],
                      key=lambda x: x[1])
        dump_json(rmse, storage_dir / 'rmse.json', indent=4, sort_keys=False)
        ex_ids_ordered = [x[0] for x in rmse]
        selected_ids = ex_ids_ordered[:10] + ex_ids_ordered[-10:]
        test_data = db.get_dataset('test_clean').shuffle(
            rng=np.random.RandomState(0))[:max_examples].filter(
                lambda x: x['example_id'] in selected_ids, lazy=False)
        test_data = prepare_dataset(test_data,
                                    audio_reader=config['audio_reader'],
                                    stft=config['stft'],
                                    max_length=10.,
                                    batch_size=1,
                                    shuffle=True)
        with torch.no_grad():
            for example in test_data:
                example = model.example_to_device(example, device)
                x = model.feature_extraction(example['stft'],
                                             example['seq_len'])
                x = model.wavenet.infer(
                    x.squeeze(1),
                    chunk_length=80_000,
                    chunk_overlap=16_000,
                )
                for i, audio in enumerate(x.cpu().detach().numpy()):
                    wavfile.write(
                        str(audio_dir / f'{example["example_id"][i]}.wav'),
                        model.sample_rate, audio)
Example #9
def main(
    json_path: Path,
    rir_dir: Path,
    wsj_json_path: Path,
    num_speakers: int,
    debug: bool,
):
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    rir_dir = Path(rir_dir).expanduser().resolve()
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    # ToDo: What was the motivation for defining this "setup"?
    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")

    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name in setup.keys():
        source_dataset_name = setup[dataset_name]["source_dataset_name"]
        source_dataset = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_dataset)}')
        source_dataset = source_dataset.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False)
        print(f'length of source {dataset_name}: {len(source_dataset)} '
              '(after punctuation filter)')

        def add_rir_path(rir_ex):
            assert 'audio_path' not in rir_ex, rir_ex
            example_id = rir_ex['example_id']
            rir_ex['audio_path'] = {
                'rir': [
                    str(rir_dir / dataset_name / example_id / f"h_{k}.wav")
                    for k in range(num_speakers)
                ]
            }
            return rir_ex

        rir_dataset = rir_db.get_dataset(dataset_name).map(add_rir_path)

        assert len(rir_dataset) % len(source_dataset) == 0, (
            f'To avoid a bias towards certain utterances, '
            f'len(rir_dataset) ({len(rir_dataset)}) should be an integer '
            f'multiple of len(source_dataset) ({len(source_dataset)}).')

        print(f'length of rir {dataset_name}: {len(rir_dataset)}')

        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.')

        info = soundfile.info(str(rir_dir / dataset_name / "0" / "h_0.wav"))
        sample_rate_rir = info.samplerate

        ex_wsj = source_dataset.random_choice(1)[0]
        info = soundfile.SoundFile(ex_wsj['audio_path']['observation'])
        sample_rate_wsj = info.samplerate
        assert sample_rate_rir == sample_rate_wsj, (sample_rate_rir,
                                                    sample_rate_wsj)

        if debug:
            rir_dataset = rir_dataset[:DEBUG_EXAMPLE_LIMIT]
            # Use step_size to avoid that only one speaker is in
            # source_iterator.
            step_size = len(source_dataset) // DEBUG_EXAMPLE_LIMIT
            source_dataset = source_dataset[::step_size]

        ex_dict = combine_rirs_and_sources(
            rir_dataset=rir_dataset,
            source_dataset=source_dataset,
            num_speakers=num_speakers,
            dataset_name=dataset_name,
        )

        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written.')
Example #10
def main(json_path: Path, rir_dir: Path, wsj_json_path: Path, num_speakers):
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    if json_path.exists():
        raise FileExistsError(json_path)
    rir_dir = Path(rir_dir).expanduser().resolve()
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")

    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name in setup.keys():
        source_dataset_name = setup[dataset_name]["source_dataset_name"]
        source_iterator = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_iterator)}')
        source_iterator = source_iterator.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False)
        print(f'length of source {dataset_name}: {len(source_iterator)} '
              '(after punctuation filter)')

        rir_iterator = rir_db.get_dataset(dataset_name)

        assert len(rir_iterator) % len(source_iterator) == 0, (
            f'To avoid a bias towards certain utterances, '
            f'len(rir_iterator) ({len(rir_iterator)}) should be an integer '
            f'multiple of len(source_iterator) ({len(source_iterator)}).')

        print(f'length of rir {dataset_name}: {len(rir_iterator)}')

        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.')

        info = soundfile.info(str(rir_dir / dataset_name / "0" / "h_0.wav"))
        sample_rate_rir = info.samplerate

        ex_wsj = source_iterator.random_choice(1)[0]
        info = soundfile.SoundFile(ex_wsj['audio_path']['observation'])
        sample_rate_wsj = info.samplerate
        assert sample_rate_rir == sample_rate_wsj, (sample_rate_rir,
                                                    sample_rate_wsj)

        rir_iterator = rir_iterator.sort(
            sort_fn=functools.partial(sorted, key=int))

        source_iterator = source_iterator.sort()
        assert len(rir_iterator) % len(source_iterator) == 0
        repeats = len(rir_iterator) // len(source_iterator)
        source_iterator = source_iterator.tile(repeats)

        speaker_ids = [example['speaker_id'] for example in source_iterator]

        rng = get_rng(dataset_name, 'example_compositions')

        example_compositions = None
        for _ in range(num_speakers):
            example_compositions = extend_example_composition_greedy(
                rng,
                speaker_ids,
                example_compositions=example_compositions,
            )

        ex_dict = dict()
        assert len(rir_iterator) == len(example_compositions)
        for rir_example, example_composition in zip(rir_iterator,
                                                    example_compositions):
            source_examples = source_iterator[example_composition]

            example_id = "_".join([
                rir_example['example_id'],
                *[ex["example_id"] for ex in source_examples],
            ])

            rng = get_rng(dataset_name, example_id)
            example = get_randomized_example(
                rir_example,
                source_examples,
                rng,
                dataset_name,
                rir_dir,
            )
            ex_dict[example_id] = example

        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written')
Example #11
def write_wavs(dst_dir, json_path, write_all=False, snr_range=(20, 30)):
    db = JsonDatabase(json_path)
    if write_all:
        if dlp_mpi.IS_MASTER:
            for data_type in KEY_MAPPER.values():
                (dst_dir / data_type).mkdir(exist_ok=False)
        map_fn = partial(scenario_map_fn,
                         snr_range=snr_range,
                         sync_speech_source=True,
                         add_speech_reverberation_early=True,
                         add_speech_reverberation_tail=True)
    else:
        if dlp_mpi.IS_MASTER:
            (dst_dir / 'observation').mkdir(exist_ok=False)
        map_fn = partial(scenario_map_fn,
                         snr_range=snr_range,
                         sync_speech_source=True,
                         add_speech_reverberation_early=False,
                         add_speech_reverberation_tail=False)
    for dataset in ['train_si284', 'cv_dev93', 'test_eval92']:
        if dlp_mpi.IS_MASTER:
            for data_type in KEY_MAPPER.values():
                (dst_dir / data_type / dataset).mkdir(exist_ok=False)
        ds = db.get_dataset(dataset).map(audio_read).map(map_fn)
        for example in dlp_mpi.split_managed(
                ds,
                is_indexable=True,
                allow_single_worker=True,
                progress_bar=True,
        ):
            audio_dict = example['audio_data']
            example_id = example['example_id']
            if not write_all:
                del audio_dict['speech_reverberation_early']
                del audio_dict['speech_reverberation_tail']
                del audio_dict['noise_image']

            def get_abs_max(a):
                if isinstance(a, np.ndarray):
                    if a.dtype == object:  # np.object was removed in NumPy 1.24
                        return np.max(list(map(get_abs_max, a)))
                    else:
                        return np.max(np.abs(a))
                elif isinstance(a, (tuple, list)):
                    return np.max(list(map(get_abs_max, a)))
                elif isinstance(a, dict):
                    return np.max(list(map(get_abs_max, a.values())))
                else:
                    raise TypeError(a)

            assert get_abs_max(audio_dict), (example_id, {
                k: get_abs_max(v)
                for k, v in audio_dict.items()
            })
            for key, value in audio_dict.items():
                if key not in KEY_MAPPER:
                    continue
                path = dst_dir / KEY_MAPPER[key] / dataset
                if key in ['observation', 'noise_image']:
                    value = value[None]
                for idx, signal in enumerate(value):
                    appendix = f'_{idx}' if len(value) > 1 else ''
                    filename = example_id + appendix + '.wav'
                    audio_path = str(path / filename)
                    with soundfile.SoundFile(audio_path,
                                             subtype='FLOAT',
                                             mode='w',
                                             samplerate=8000,
                                             channels=1 if signal.ndim == 1
                                             else signal.shape[0]) as f:
                        f.write(signal.T)

        dlp_mpi.barrier()

    if dlp_mpi.IS_MASTER:
        created_files = check_files(dst_dir)
        print(f"Written {len(created_files)} wav files.")
        if write_all:
            # TODO Less, if you do a test run.
            num_speakers = 2  # todo infer num_speakers from json
            # 2 files for: early, tail, speech_source
            # 1 file for: observation, noise
            expect = (3 * num_speakers + 2) * 35875
            assert len(created_files) == expect, (len(created_files), expect)
        else:
            assert len(created_files) == 35875, len(created_files)
Example #12
def get_datasets(use_noisy, split, fold, curated_reps, mixup_probs, extractor,
                 augmenter, num_workers, batch_size, prefetch_buffer,
                 max_padding_rate, bucket_expiration, event_bucketing, debug):
    # prepare database
    database_json = jsons_dir / f'fsd_kaggle_2019_split{split}.json'
    db = JsonDatabase(database_json)

    def add_noisy_flag(example):
        example['is_noisy'] = example['dataset'] != 'train_curated'
        return example

    extractor = Extractor(**extractor)
    augmenter = Augmenter(extractor=extractor, **augmenter)

    curated_train_data = db.get_dataset('train_curated').map(add_noisy_flag)
    extractor.initialize_labels(curated_train_data)
    if debug:
        curated_train_data = curated_train_data.shuffle()[
            :len(curated_train_data) // 10]
    extractor.initialize_norm(dataset_name='train_curated',
                              dataset=curated_train_data,
                              max_workers=num_workers)

    if fold is not None:
        curated_train_data, validation_set = split_dataset(curated_train_data,
                                                           fold=fold,
                                                           seed=0)
    else:
        validation_set = None

    if use_noisy:
        noisy_train_data = db.get_dataset('train_noisy').map(add_noisy_flag)
        if debug:
            noisy_train_data = noisy_train_data.shuffle()[
                :len(noisy_train_data) // 10]
        extractor.initialize_norm(dataset_name='train_noisy',
                                  dataset=noisy_train_data,
                                  max_workers=num_workers)
        training_set = lazy_dataset.concatenate(curated_train_data,
                                                noisy_train_data)
    else:
        training_set = curated_train_data
    batch_norm_tuning_set = training_set

    if mixup_probs is not None:
        training_set = MixUpDataset(training_set,
                                    training_set,
                                    mixin_probs=mixup_probs)
    if curated_reps > 0:
        print('curated reps:', curated_reps)
        curated_train_data = lazy_dataset.from_dict({
            f'{example["example_id"]}_{i}': example
            for i in range(curated_reps) for example in curated_train_data
        })
        if mixup_probs is not None:
            curated_train_data = MixUpDataset(curated_train_data,
                                              curated_train_data,
                                              mixin_probs=mixup_probs)
        training_set = lazy_dataset.concatenate(training_set,
                                                curated_train_data)

    print('Length of training set', len(training_set))

    bucket_cls = EventTimeSeriesBucket if event_bucketing \
        else DynamicTimeSeriesBucket

    def prepare_iterable(dataset, drop_incomplete=False):
        return dataset.prefetch(
            num_workers=num_workers,
            buffer_size=prefetch_buffer,
            catch_filter_exception=True).batch_dynamic_bucket(
                bucket_cls=bucket_cls,
                batch_size=batch_size,
                len_key='seq_len',
                max_padding_rate=max_padding_rate,
                expiration=bucket_expiration,
                drop_incomplete=drop_incomplete,
                sort_key='seq_len',
                reverse_sort=True).map(Collate())

    training_set = prepare_iterable(
        training_set.map(augmenter).shuffle(reshuffle=True),
        drop_incomplete=True)
    batch_norm_tuning_set = prepare_iterable(
        batch_norm_tuning_set.map(extractor), drop_incomplete=True)
    if validation_set is not None:
        validation_set = prepare_iterable(validation_set.map(extractor))

    return training_set, validation_set, batch_norm_tuning_set
Example #13
def get_datasets(
        use_noisy, split, relabeled, fold, curated_reps, mixup_probs,
        audio_reader, stft, mel_transform, augmenter, num_workers, batch_size,
        prefetch_buffer, max_padding_rate, bucket_expiration, event_bucketing,
        debug
):
    # prepare database
    database_json = jsons_dir / \
        f'fsd_kaggle_2019_split{split}{"_relabeled" if relabeled else ""}.json'
    db = JsonDatabase(database_json)

    def add_noisy_flag(example):
        example['is_noisy'] = example['dataset'] != 'train_curated'
        return example

    audio_reader = AudioReader(**audio_reader)
    stft = STFT(**stft)
    mel_transform = MelTransform(**mel_transform)
    normalizer = Normalizer(storage_dir=str(storage_dir))
    augmenter = Augmenter(**augmenter)

    curated_train_data = db.get_dataset('train_curated').map(add_noisy_flag)

    event_encoder = MultiHotLabelEncoder(
        label_key='events', storage_dir=storage_dir
    )
    event_encoder.initialize_labels(
        dataset=curated_train_data, verbose=True
    )

    if debug:
        curated_train_data = curated_train_data.shuffle()[:500]

    normalizer.initialize_norm(
        dataset_name='train_curated',
        dataset=curated_train_data.map(audio_reader).map(stft).map(mel_transform),
        max_workers=num_workers,
    )

    if fold is not None:
        curated_train_data, validation_set = split_dataset(
            curated_train_data, fold=fold, seed=0
        )
    else:
        validation_set = None

    if use_noisy:
        noisy_train_data = db.get_dataset('train_noisy').map(add_noisy_flag)
        if debug:
            noisy_train_data = noisy_train_data.shuffle()[:500]

        normalizer.initialize_norm(
            dataset_name='train_noisy',
            dataset=noisy_train_data.map(audio_reader).map(stft).map(mel_transform),
            max_workers=num_workers,
        )
        training_set = lazy_dataset.concatenate(curated_train_data, noisy_train_data)
    else:
        training_set = curated_train_data
    batch_norm_tuning_set = training_set

    if mixup_probs is not None:
        training_set = MixUpDataset(
            training_set, training_set, mixin_probs=mixup_probs
        )
    if curated_reps > 0:
        print('curated reps:', curated_reps)
        curated_train_data = lazy_dataset.from_dict({
            f'{example["example_id"]}_{i}': example
            for i in range(curated_reps)
            for example in curated_train_data
        })
        if mixup_probs is not None:
            curated_train_data = MixUpDataset(
                curated_train_data, curated_train_data, mixin_probs=mixup_probs
            )
        training_set = lazy_dataset.concatenate(
            training_set, curated_train_data
        )

    print('Length of training set', len(training_set))
    if validation_set is not None:
        print('Length of validation set', len(validation_set))

    def finalize(example):
        x = example['features']
        example_ = {
            'example_id': example['example_id'],
            'dataset': example['dataset'],
            'is_noisy': np.array(example['is_noisy']).astype(np.float32),
            'features': x.astype(np.float32),
            'seq_len': x.shape[1],
        }
        if 'events' in example:
            example_['events'] = example['events']
        return example_

    bucket_cls = EventTimeSeriesBucket if event_bucketing \
        else DynamicTimeSeriesBucket

    def prepare_iterable(dataset, drop_incomplete=False):
        return dataset.map(event_encoder).map(finalize).prefetch(
            num_workers=num_workers, buffer_size=prefetch_buffer,
            catch_filter_exception=True
        ).batch_dynamic_bucket(
            bucket_cls=bucket_cls, batch_size=batch_size, len_key='seq_len',
            max_padding_rate=max_padding_rate, expiration=bucket_expiration,
            drop_incomplete=drop_incomplete, sort_key='seq_len',
            reverse_sort=True
        ).map(Collate())

    training_set = prepare_iterable(
        training_set.map(audio_reader).map(stft).map(mel_transform)
        .map(normalizer).map(augmenter).shuffle(reshuffle=True),
        drop_incomplete=True
    )
    batch_norm_tuning_set = prepare_iterable(
        batch_norm_tuning_set.map(audio_reader).map(stft).map(mel_transform)
        .map(normalizer),
        drop_incomplete=True
    )
    if validation_set is not None:
        validation_set = prepare_iterable(
            validation_set.map(audio_reader).map(stft).map(mel_transform)
            .map(normalizer)
        )

    return training_set, validation_set, batch_norm_tuning_set
Example #14
def get_validation_dataset(database: JsonDatabase):
    # AudioReader is a specialized function to read audio organized
    # in a json as described in pb.database.database
    val_iterator = database.get_dataset('dt05_simu')
    return val_iterator.map(prepare_data) \
        .prefetch(num_workers=4, buffer_size=4)
Example #15
def get_train_dataset(database: JsonDatabase):
    train_ds = database.get_dataset('tr05_simu')
    return train_ds.map(prepare_data).prefetch(num_workers=4, buffer_size=4)
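The datasets returned by these helpers are plain iterables, so consuming them in a training loop needs no extra machinery. A short sketch (json path and epoch count are placeholders):

from lazy_dataset.database import JsonDatabase

database = JsonDatabase('/path/to/database.json')  # hypothetical path
train_ds = get_train_dataset(database)
for epoch in range(10):  # placeholder epoch count
    for example in train_ds:
        pass  # forward/backward pass would go here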