    def embed(self, aud):

        mel = mel_spectogram(aud)  # mel spectrogram of the preprocessed audio
        # Slide a fixed-size window along the time axis, stepping by
        # SLIDING_WIN_SIZE * SLIDING_WIN_OVERLAP frames each iteration.
        part_frames = []
        for ix in range(
                0, mel.shape[1] - self.config_yml['SLIDING_WIN_SIZE'],
                int(self.config_yml['SLIDING_WIN_SIZE'] *
                    self.config_yml['SLIDING_WIN_OVERLAP'])):
            part_frame = mel[:, ix:ix + self.config_yml['SLIDING_WIN_SIZE']]
            part_frames.append(part_frame)


        # Stack the windows into a batch of shape
        # (num_windows, MEL_CHANNELS, SLIDING_WIN_SIZE) on the target device.
        frames = np.stack(part_frames)
        frames = torch.Tensor(frames).view(
            -1,
            self.config_yml['MEL_CHANNELS'],
            self.config_yml['SLIDING_WIN_SIZE'],
        ).to(device=self.device)
        # Restore the trained weights saved for the requested epoch.
        model_pickle = torch.load(self.model_save_string.format(self.epoch),
                                  map_location=self.device)
        self.load_state_dict(model_pickle['model_state_dict'])
        with torch.no_grad():
            self.eval()
            embeds = self.forward(frames)
            # L2-normalize each window embedding, then average them into a
            # single utterance-level embedding.
            embeds = embeds * torch.reciprocal(
                torch.norm(embeds, dim=1, keepdim=True))
            embeds = torch.mean(embeds, dim=0)
            embeds = embeds.cpu().data.numpy()

        return embeds
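
A minimal usage sketch for the method above (hypothetical: `VoiceEncoder`, `load_audio`, and `config.yml` are stand-in names for the surrounding project's encoder class, audio loader, and config file, none of which appear in the snippet itself):

encoder = VoiceEncoder('config.yml', device='cpu')  # hypothetical constructor
aud = load_audio('sample.wav')                      # preprocessed 1-D waveform (np.ndarray)
embedding = encoder.embed(aud)                      # averaged, unit-norm numpy vector
print(embedding.shape)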
Example #2
    def embed(self, aud, group=True):

        # Compute aligned slice indices that partition the waveform and its
        # mel spectrogram into fixed-size windows.
        aud_splits, mel_splits = split_audio_ixs(len(aud))
        max_aud_length = aud_splits[-1].stop

        # Zero-pad the audio so the final window is fully covered.
        if max_aud_length >= len(aud):
            aud = np.pad(aud, (0, max_aud_length - len(aud)), "constant")

        mel = mel_spectogram(aud).astype(np.float32).T  # (n_frames, mel_channels)
        mels = np.array([mel[s] for s in mel_splits])   # one window per slice
        mels = torch.from_numpy(mels).to(self.device)
        embeds_all = []
        with torch.no_grad():
            # Encode the windows in batches of 64; slicing clamps at the end
            # of the tensor, and mels already lives on self.device.
            for i in range(0, mels.shape[0], 64):
                embeds = self.forward(mels[i:i + 64])
                # Unit-normalize each window embedding.
                embeds = embeds / torch.norm(embeds, dim=1, keepdim=True)
                embeds_all.append(embeds)
        embeds = torch.cat(embeds_all)

        if group:
            # Average the window embeddings into a single utterance embedding
            # and renormalize it to unit length.
            embeds = torch.mean(embeds, dim=0)
            embeds = embeds / torch.norm(embeds)

        embeds = embeds.cpu().data.numpy()

        return embeds, (aud_splits, mel_splits)
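
A similar hypothetical usage sketch for this variant (`VoiceEncoder` and `load_audio` are again stand-in names); with `group=False` it returns one embedding per window plus the split slices:

encoder = VoiceEncoder(device='cpu')                    # hypothetical constructor
aud = load_audio('sample.wav')                          # 1-D waveform (np.ndarray)
utt_embed, _ = encoder.embed(aud, group=True)           # single unit-norm vector
win_embeds, (aud_splits, mel_splits) = encoder.embed(aud, group=False)
# win_embeds: one unit-norm row per window; aud_splits/mel_splits are slices
# mapping each row back to its span of the waveform and spectrogram.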