Example #1
    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        # Build an audio augmentor if one is configured
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        # Automatically inject args from model config to dataloader config
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels')

        shuffle = config['shuffle']
        device = 'gpu' if torch.cuda.is_available() else 'cpu'
        # Optionally build a NVIDIA DALI pipeline, which can run loading and preprocessing on the GPU
        if config.get('use_dali', False):
            device_id = self.local_rank if device == 'gpu' else None
            dataset = audio_to_text_dataset.get_dali_char_dataset(
                config=config,
                shuffle=shuffle,
                device_id=device_id,
                global_rank=self.global_rank,
                world_size=self.world_size,
                preprocessor_cfg=self._cfg.preprocessor,
            )
            return dataset

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
                'manifest_filepath' in config and config['manifest_filepath'] is None
            ):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` or "
                    f"`tarred_audio_filepaths` was None. Provided config: {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = audio_to_text_dataset.get_tarred_char_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
                augmentor=augmentor,
            )
            shuffle = False  # tarred datasets shuffle internally (shuffle_n); the DataLoader must not reshuffle
        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config: {config}")
                return None

            dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
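
For orientation, here is a minimal sketch of the kind of config dict this method consumes. The keys mirror the accesses in the method above; the values are illustrative assumptions, not NeMo defaults.

# Hypothetical dataloader config for _setup_dataloader_from_config();
# every key below corresponds to a config access in the method above,
# and all values are made up for illustration.
train_ds_config = {
    'manifest_filepath': '/data/train_manifest.json',  # assumed path
    'sample_rate': 16000,        # injected from the model config when absent
    'labels': [' ', 'a', 'b'],   # injected from the model config when absent
    'batch_size': 32,
    'shuffle': True,             # required: read via config['shuffle']
    'use_dali': False,           # True routes to the DALI branch
    'is_tarred': False,          # True routes to the tarred-dataset branch
    'num_workers': 4,
    'pin_memory': True,
    'drop_last': False,
}
# loader = model._setup_dataloader_from_config(config=train_ds_config)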
Example #2
    def test_dali_char_vs_ref_dataset(self, test_data_dir):
        manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/an4_val.json'))

        num_samples = 10
        batch_size = 1
        device = 'gpu' if torch.cuda.is_available() else 'cpu'
        texts = []

        with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
            with open(manifest_path, 'r') as m:
                for ix, line in enumerate(m):
                    if ix >= num_samples:
                        break

                    # Rewrite the path prefix for the unpacked test data and drop the trailing newline
                    line = line.replace("tests/data/", "tests/.data/").replace("\n", "")
                    f.write(f"{line}\n")

                    data = json.loads(line)
                    texts.append(data['text'])

            f.seek(0)  # flushes buffered writes so the datasets can read the manifest from disk

            preprocessor = {
                '_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
                'dither': 0.0,
            }
            preprocessor_cfg = DictConfig(preprocessor)

            dataset_cfg = {
                'manifest_filepath': f.name,
                'sample_rate': 16000,
                'labels': self.labels,
                'batch_size': batch_size,
                'trim_silence': False,
                'max_duration': 16.7,
                'shuffle': False,
                'is_tarred': False,
            }
            dali_dataset = audio_to_text_dataset.get_dali_char_dataset(
                config=dataset_cfg,
                shuffle=False,
                device_id=0,
                global_rank=0,
                world_size=1,
                preprocessor_cfg=preprocessor_cfg,
            )
            ref_dataset = audio_to_text_dataset.get_char_dataset(config=dataset_cfg)
            ref_dataloader = DataLoader(
                dataset=ref_dataset,
                batch_size=batch_size,
                collate_fn=ref_dataset.collate_fn,
                drop_last=False,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
            )
            ref_preprocessor = EncDecCTCModel.from_config_dict(preprocessor_cfg)

            for ref_data, dali_data in zip(ref_dataloader, dali_dataset):
                ref_audio, ref_audio_len, _, _ = ref_data
                ref_features, ref_features_len = ref_preprocessor(input_signal=ref_audio, length=ref_audio_len)

                dali_features, dali_features_len, _, _ = dali_data

                # Trim both feature tensors to their valid frame counts before comparing
                # (batch_size is 1, so the one-element length tensors index cleanly)
                a = ref_features.cpu().numpy()[:, :, :ref_features_len]
                b = dali_features.cpu().numpy()[:, :, :dali_features_len]

                err = np.abs(a - b)
                assert np.mean(err) < 0.0001  # mean absolute error tolerance
                assert np.max(err) < 0.01  # worst-case elementwise error tolerance
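
For context, each manifest line the loop above parses is a single JSON object. A minimal sketch of one entry in the standard NeMo ASR manifest layout follows; the path and values are made up, and `text` is the field the test collects into `texts`.

import json

# Illustrative manifest entry (hypothetical file and transcript)
entry = {
    "audio_filepath": "tests/.data/asr/an4_val/wav/an4_0.wav",
    "duration": 2.5,
    "text": "one two three",
}
print(json.dumps(entry))  # manifests store one such JSON object per line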