Example 1
    def open(self, item) -> AudioItem:
        p = Path(item)
        if self.path is not None and str(self.path) not in str(item): p = self.path/item
        if not p.exists(): 
            raise FileNotFoundError(f"Neither '{item}' nor '{p}' could be found")
        if not str(p).lower().endswith(AUDIO_EXTENSIONS): raise ValueError(f"Invalid audio file '{p}'")

        cfg = self.config
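        # If spectrogram caching is enabled, try to return a previously cached spectrogram first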
        if cfg.use_spectro:
            cache_dir = self.path / cfg.cache_dir
            folder = md5(str(asdict(cfg))+str(asdict(cfg.sg_cfg)))
            fname = f"{md5(str(p))}-{p.name}.pt"
            image_path = cache_dir/(f"{folder}/{fname}")
            if cfg.cache and not cfg.force_cache and image_path.exists():
                mel = torch.load(image_path).squeeze()
                start, end = None, None
                if cfg.duration and cfg._processed:
                    mel, start, end = tfm_crop_time(mel, cfg._sr, cfg.duration, cfg.sg_cfg.hop)
                return AudioItem(spectro=mel, path=item, max_to_pad=cfg.max_to_pad, start=start, end=end)

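        # Cache miss (or spectrograms disabled): load the raw waveform from disk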
        signal, samplerate = torchaudio.load(str(p))
        if cfg._sr is not None and samplerate != cfg._sr:
            raise ValueError(
                f"Multiple sample rates detected: sample rate {samplerate} of file '{p}' does not "
                f"match the config sample rate {cfg._sr}. Your dataset contains multiple sample "
                f"rates; choose one and set resample_to to that value.")

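        # Pad or trim the signal to a fixed length, given in milliseconds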
        if cfg.max_to_pad or cfg.segment_size:
            pad_len = cfg.max_to_pad if cfg.max_to_pad is not None else cfg.segment_size
            signal = PadTrim(max_len=int(pad_len/1000*samplerate))(signal)

        mel = None
        start, end = None, None
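        # Compute the mel spectrogram (or MFCCs), cache it, and optionally crop it to the configured duration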
        if cfg.use_spectro:
            if cfg.mfcc: mel = MFCC(sr=samplerate, n_mfcc=cfg.sg_cfg.n_mfcc, melkwargs=asdict(cfg.sg_cfg))(signal.reshape(1,-1))
            else:
                mel = MelSpectrogram(**(cfg.sg_cfg.mel_args()))(signal.reshape(1, -1))
                if cfg.sg_cfg.to_db_scale: mel = SpectrogramToDB(top_db=cfg.sg_cfg.top_db)(mel)
            mel = mel.squeeze().permute(1, 0)
            if cfg.standardize: mel = standardize(mel)
            if cfg.delta: mel = torch.stack([mel, torchdelta(mel), torchdelta(mel, order=2)]) 
            else: mel = mel.expand(3,-1,-1)
            if cfg.cache:
                os.makedirs(image_path.parent, exist_ok=True)
                torch.save(mel, image_path)
            start, end = None, None
            if cfg.duration and cfg._processed: 
                mel, start, end = tfm_crop_time(mel, cfg._sr, cfg.duration, cfg.sg_cfg.hop)
        return AudioItem(sig=signal.squeeze(), sr=samplerate, spectro=mel, path=item, start=start, end=end)
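The cache layout above keys each spectrogram on the full configuration and on the audio file's path. Below is a minimal sketch of that derivation using plain hashlib; the `_md5` helper and the `cached_spectro_path` name are illustrative assumptions based on the code above, not part of the library's API:

    import hashlib
    from pathlib import Path

    def _md5(s: str) -> str:
        # Hex digest of an arbitrary string; stands in for the md5() helper used above.
        return hashlib.md5(s.encode()).hexdigest()

    def cached_spectro_path(cache_dir: Path, cfg_dict: dict, sg_cfg_dict: dict, audio_file: Path) -> Path:
        # One folder per configuration: changing any spectrogram setting changes the hash,
        # so stale spectrograms are never reused.
        folder = _md5(str(cfg_dict) + str(sg_cfg_dict))
        # One file per clip, keyed by the hash of its full path plus its original name.
        fname = f"{_md5(str(audio_file))}-{audio_file.name}.pt"
        return cache_dir / folder / fname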
Example 2
    def open(self, item) -> AudioItem:
        p = Path(item)
        if self.path is not None and str(self.path) not in str(item):
            p = self.path / item
        if not p.exists():
            raise FileNotFoundError(
                f"Neither '{item}' nor '{p}' could be found")
        if not str(p).lower().endswith(AUDIO_EXTENSIONS):
            raise ValueError(f"Invalid audio file '{p}'")

        cfg = self.config
        if cfg.use_spectro:
            folder = md5(str(asdict(cfg)) + str(asdict(cfg.sg_cfg)))
            fname = f"{md5(str(p))}-{p.name}.pt"
            image_path = cfg.cache_dir / (f"{folder}/{fname}")
            if cfg.cache and not cfg.force_cache and image_path.exists():
                mel = torch.load(image_path).squeeze()
                start, end = None, None
                if cfg.duration and cfg._processed:
                    mel, start, end = tfm_crop_time(mel, cfg._sr, cfg.duration,
                                                    cfg.sg_cfg.hop,
                                                    cfg.pad_mode)
                return AudioItem(spectro=mel,
                                 path=item,
                                 max_to_pad=cfg.max_to_pad,
                                 start=start,
                                 end=end)

        sig, sr = torchaudio.load(str(p))
        if cfg._sr is not None and sr != cfg._sr:
            raise ValueError(
                f"Multiple sample rates detected: sample rate {sr} of file '{p}' does not "
                f"match the config sample rate {cfg._sr}. Your dataset contains multiple "
                f"sample rates; choose one and set resample_to to that value.")
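        # Downmix multi-channel audio to mono before any further processing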
        if sig.shape[0] > 1:
            if not cfg.downmix:
                warnings.warn(
                    f"Audio file {p} has {sig.shape[0]} channels; automatically downmixing "
                    f"to mono. Set AudioConfig.downmix=True to silence this warning.")
            sig = DownmixMono(channels_first=True)(sig)
        if cfg.max_to_pad or cfg.segment_size:
            pad_len = cfg.max_to_pad if cfg.max_to_pad is not None else cfg.segment_size
            sig = tfm_padtrim_signal(sig,
                                     int(pad_len / 1000 * sr),
                                     pad_mode="zeros")

        mel = None
        start, end = None, None
        if cfg.use_spectro:
            if cfg.mfcc:
                mel = MFCC(sr=sr,
                           n_mfcc=cfg.sg_cfg.n_mfcc,
                           melkwargs=cfg.sg_cfg.mel_args())(sig)
            else:
                mel = MelSpectrogram(**(cfg.sg_cfg.mel_args()))(sig)
                if cfg.sg_cfg.to_db_scale:
                    mel = SpectrogramToDB(top_db=cfg.sg_cfg.top_db)(mel)
            mel = mel.squeeze().permute(1, 0).flip(0)
            if cfg.standardize: mel = standardize(mel)
            if cfg.delta:
                mel = torch.stack(
                    [mel, torchdelta(mel),
                     torchdelta(mel, order=2)])
            else:
                mel = mel.expand(3, -1, -1)
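            # Save the spectrogram to the cache and record the newly cached file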
            if cfg.cache:
                os.makedirs(image_path.parent, exist_ok=True)
                torch.save(mel, image_path)
                _record_cache_contents(cfg, [image_path])
            start, end = None, None
            if cfg.duration and cfg._processed:
                mel, start, end = tfm_crop_time(mel, cfg._sr, cfg.duration,
                                                cfg.sg_cfg.hop, cfg.pad_mode)
        return AudioItem(sig=sig.squeeze(),
                         sr=sr,
                         spectro=mel,
                         path=item,
                         start=start,
                         end=end)
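Compared with Example 1, this version also downmixes multi-channel audio before computing the spectrogram. `DownmixMono` comes from an older torchaudio release; a minimal equivalent is sketched below, assuming simple channel averaging rather than the library's exact implementation:

    import torch

    def downmix_mono(sig: torch.Tensor) -> torch.Tensor:
        # sig has shape (channels, samples), channels-first as returned by torchaudio.load.
        # Average across the channel dimension, keeping a leading channel axis of size 1.
        return sig.mean(dim=0, keepdim=True)

    stereo = torch.randn(2, 16000)   # e.g. one second of 16 kHz two-channel noise
    mono = downmix_mono(stereo)      # shape (1, 16000)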