Ejemplo n.º 1
0
    def take_log(self, mels):
        """Convert mel spectrograms from amplitude scale to decibels.

        The conversion treats the input as amplitude (not power), floors it at
        1e-5 (librosa's default ``amin``) and clips the result to [-50, 80] dB
        so the output matches the legacy pipeline.

        Args:
            mels: torch.Tensor, mel spectrograms to convert.

        Returns:
            torch.Tensor: the dB-scaled mel spectrogram, same shape as ``mels``.
        """
        to_db = AmplitudeToDB(stype="amplitude")
        # AmplitudeToDB's constructor does not take ``amin``; override the
        # attribute afterwards to match librosa.
        to_db.amin = 1e-5
        mels_db = to_db(mels)
        # Clipping kept only for parity with the previous implementation.
        return mels_db.clamp(min=-50, max=80)
def spectrogram(trace: Trace):
    """Compute a dB mel spectrogram of a seismic trace.

    Processing steps:
    - resample to the module-level ``sampling_rate``;
    - remove the mean and taper the edges;
    - trim or zero-pad to ``sequence_length_second`` seconds;
    - apply the mel transform.

    The caller's ``trace`` is not modified. ObsPy's ``resample``, ``taper``
    and ``trim`` operate in place, so every step runs on a copy.

    Args:
        trace: ObsPy ``Trace`` holding the raw waveform.

    Returns:
        numpy.ndarray: mel spectrogram in dB, shifted so that its minimum is 0.
    """
    # Work on a copy: the in-place ObsPy operations below would otherwise
    # silently alter the trace owned by the caller.
    trace = trace.copy()
    trace.resample(sampling_rate)

    mel_spec = MelSpectrogram(sample_rate=sampling_rate,
                              n_mels=image_height,
                              hop_length=hop_length,
                              power=1,
                              pad_mode='reflect',
                              normalized=True)

    # NOTE(review): the mel spectrogram is a magnitude (power=1), but
    # AmplitudeToDB() defaults to stype="power" (10*log10). This is kept as-is
    # so trained models keep seeing the same input scale.
    amplitude_to_db = AmplitudeToDB()

    # Remove the DC offset, taper the edges and force a fixed duration.
    trace.data = trace.data - np.mean(trace.data)
    trace = trace.taper(max_length=0.01, max_percentage=0.05)
    trace = trace.trim(starttime=trace.stats.starttime,
                       endtime=trace.stats.starttime + sequence_length_second,
                       pad=True,
                       fill_value=0)

    torch_data = torch.tensor(trace.data).type(torch.float32)

    spec = mel_spec(torch_data)
    # The +1e-3 offset puts a floor under the dB values for near-silent bins.
    spec_db = amplitude_to_db(spec.abs() + 1e-3)
    # Shift so the smallest value becomes 0.
    spec_db = (spec_db - spec_db.min()).numpy()
    return spec_db
Ejemplo n.º 3
0
    def __init__(self,
                 output_class=264,
                 d_size=256,
                 sample_rate=32000,
                 n_fft=2**11,
                 top_db=80):
        """Build the layers: mel/dB front end, three conv stages, an LSTM and the classifier.

        Args:
            output_class: number of outputs of the final linear layer.
            d_size: accepted for interface compatibility; not used here.
            sample_rate: sample rate passed to ``MelSpectrogram``.
            n_fft: FFT size passed to ``MelSpectrogram``.
            top_db: dynamic-range cut passed to ``AmplitudeToDB``.
        """
        super().__init__()
        # Submodule assignment order defines parameter / state_dict order;
        # keep it stable so existing checkpoints still load.
        self.mel = MelSpectrogram(sample_rate, n_fft=n_fft)
        self.norm_db = AmplitudeToDB(top_db=top_db)

        # nn.ReLU's only argument is ``inplace``. The previous ``nn.ReLU(0.1)``
        # passed a truthy 0.1, i.e. an in-place ReLU, which is now spelled out.
        # NOTE(review): 0.1 suggests LeakyReLU(0.1) may have been intended.
        # Confirm before switching, since that would change trained models.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=[0, 0])
        self.bn1 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
        self.dropout = nn.Dropout(0.1)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=[0, 0])
        self.bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
        self.dropout2 = nn.Dropout(0.1)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=[0, 0])
        self.bn3 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
        self.dropout3 = nn.Dropout(0.1)

        # NOTE(review): LSTM input size 4 presumably equals the feature size
        # left after the conv/pool stack. Confirm against forward().
        self.lstm = nn.LSTM(4, 128, 2, batch_first=True)
        self.dropout_lstm = nn.Dropout(0.3)
        self.bn_lstm = nn.BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        self.output = nn.Linear(128, output_class, bias=True)
Ejemplo n.º 4
0
    def __init__(self,
                 data_loud_mu,
                 data_loud_std,
                 sr=SAMPLE_RATE,
                 n_fft=1024,
                 hop_length=256,
                 out_size=250):
        """
        Construct an instance of LoudnessEncoder

        Args:
            data_loud_mu (torch.Tensor): The mean loudness of the dataset
            data_loud_std (torch.Tensor): The standard deviation of the
                loudness of the dataset
            sr (int, optional): Sample rate. Defaults to SAMPLE_RATE.
            n_fft (int, optional): FFT size. Defaults to 1024.
            hop_length (int, optional): FFT hop length. Defaults to 256.
            out_size (int, optional): Output length in frames. Defaults to 250.
        """
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.data_loud_mu = data_loud_mu
        self.data_loud_std = data_loud_std
        self.time_dim = int(out_size)

        self.stft = Spectrogram(n_fft=n_fft, hop_length=hop_length)
        self.db = AmplitudeToDB()
        # Centre frequency (Hz) of each one-sided STFT bin:
        # k * sr / n_fft for k = 0 .. n_fft // 2.
        # There are n_fft // 2 + 1 values, the same as librosa.fft_frequencies.
        # The previous linspace(0, sr // 2, n_fft // 2 + 2)[:-1] spaced the
        # bins by (sr / 2) / (n_fft // 2 + 1) and never reached Nyquist.
        self.centre_freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
        # A-weighting per bin, reshaped to (n_bins, 1). Presumably this lets it
        # broadcast over the time axis of the spectrogram; TODO confirm in
        # forward().
        self.register_buffer(
            "weightings",
            torch.tensor(librosa.core.A_weighting(self.centre_freqs)).view(
                -1, 1))
Ejemplo n.º 5
0
    def __init__(self,
                 x_shape,
                 sr=44100,
                 n_fft=1024,
                 n_mels=256,
                 win_len=256,
                 hop_len=128):
        """Create the mel and dB transforms and precompute the output shape.

        Args:
            x_shape: shape of one input example without the batch dimension.
                Its elements are multiplied together to get the sample count.
            sr: sample rate for ``MelSpectrogram``.
            n_fft: FFT size.
            n_mels: number of mel bands (dimension 2 of ``output_shape``).
            win_len: STFT window length in samples.
            hop_len: STFT hop length in samples.
        """
        super(ProcessMelSpectrogram, self).__init__()
        # Original spectrogram process: sr 11025, n_fft 1024, n_mels 256, win_len 256, hop_len 8
        # Original output shape 256, 92
        # librosa default params: sr 22050, n_fft 2048, n_mels ?, win_len 2048, hop_len 512
        # music processing: 93 ms, speech processing: 23 ms (computed by 1/(sr/hop_len))
        self.mel_s = MelSpectrogram(sample_rate=sr,
                                    n_fft=n_fft,
                                    n_mels=n_mels,
                                    win_length=win_len,
                                    hop_length=hop_len)
        self.a_to_db = AmplitudeToDB(top_db=80)

        # The batch dimension is left as -1 (unknown).
        self.x_shape = [-1] + list(x_shape)
        assert len(self.x_shape) in [2, 3, 4]

        num_samples = np.prod(self.x_shape[1:])
        # MelSpectrogram defaults to center=True, which yields
        # 1 + num_samples // hop_len frames. The previous ceil-division was one
        # frame short whenever num_samples is a multiple of hop_len.
        spec_width = 1 + num_samples // hop_len
        self.output_shape = [self.x_shape[0], 1, n_mels, spec_width]
Ejemplo n.º 6
0
    def __init__(self,
                 output_class=264,
                 d_size=256,
                 sample_rate=32000,
                 n_fft=2**11,
                 top_db=80):
        """Build the layers: mel/dB front end, three conv stages, an LSTM and the classifier.

        Args:
            output_class: number of outputs of the final linear layer.
            d_size: accepted for interface compatibility; not used here.
            sample_rate: sample rate passed to ``MelSpectrogram``.
            n_fft: FFT size passed to ``MelSpectrogram``.
            top_db: dynamic-range cut passed to ``AmplitudeToDB``.
        """
        super().__init__()
        # Submodule assignment order defines parameter / state_dict order;
        # keep it stable so existing checkpoints still load.
        self.mel = MelSpectrogram(sample_rate, n_fft=n_fft)
        self.norm_db = AmplitudeToDB(top_db=top_db)

        # conv1 expects 3 input channels.
        # NOTE(review): presumably the dB image is stacked 3x; TODO confirm.
        # nn.ReLU's only argument is ``inplace``. The previous ``nn.ReLU(0.1)``
        # passed a truthy 0.1, i.e. an in-place ReLU, which is now spelled out.
        # NOTE(review): 0.1 suggests LeakyReLU(0.1) may have been intended.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=3)
        self.dropout = nn.Dropout(0.1)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=3)
        self.dropout2 = nn.Dropout(0.1)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=3)
        self.dropout3 = nn.Dropout(0.1)

        # NOTE(review): LSTM input size 12 presumably equals the feature size
        # left after the conv/pool stack. Confirm against forward().
        self.lstm = nn.LSTM(12, 128, 2, batch_first=True)
        self.dropout_lstm = nn.Dropout(0.3)
        self.bn_lstm = nn.BatchNorm1d(128)

        self.output = nn.Linear(128, output_class)
Ejemplo n.º 7
0
def collate_fn(data, device=device):
    """Stack waveforms into a batch and return min-max scaled dB mel spectrograms.

    Args:
        data: sequence of equally-shaped waveform tensors, one per example.
        device: currently unused; kept for interface compatibility.

    Returns:
        torch.Tensor: dB mel spectrograms scaled to [0, 1] using the minimum
        and maximum over the whole batch. If every value is identical (e.g. an
        all-silent batch), the result is all zeros instead of NaN.
    """
    batch = torch.stack(data)
    x = MelSpectrogram(sample_rate=sample_rate)(batch)
    x = AmplitudeToDB(stype='power', top_db=80)(x)
    minval = x.min()
    value_range = x.max() - minval
    x = x - minval
    # Only rescale when there is a spread; otherwise 0/0 would yield NaN.
    if value_range > 0:
        x = x / value_range
    return x
 def __init__(self, arch, num_classes=10):
     """Wrap a backbone from ``Models`` behind a mel-spectrogram / dB front end.

     Args:
         arch: key of the architecture factory looked up in ``Models.__dict__``.
         num_classes: number of output classes forwarded to that factory.
     """
     super(ModelCalled, self).__init__()
     # Together with FolderDataset.duration these settings give a mel image
     # of shape (128, 128); remember to keep f_max at 8000.
     mel_settings = dict(sample_rate=16384,
                         n_fft=2048,
                         hop_length=512,
                         f_max=8000,
                         n_mels=128)
     self.melspectrogram = MelSpectrogram(**mel_settings)
     self.power2db = AmplitudeToDB(stype='power')
     backbone_factory = Models.__dict__[arch]
     self.model = backbone_factory(num_classes=num_classes)
Ejemplo n.º 9
0
    def __init__(self, sample_rate, n_fft, hop_length, n_mels, top_db=None):
        """Create a mel spectrogram transform followed by a dB conversion.

        Args:
            sample_rate: input sample rate forwarded to ``MelSpectrogram``.
            n_fft: FFT size.
            hop_length: hop between successive frames, in samples.
            n_mels: number of mel bands.
            top_db: optional dynamic-range cut for ``AmplitudeToDB``
                (``None`` means no cut).
        """
        super().__init__()

        mel_kwargs = dict(sample_rate=sample_rate,
                          n_fft=n_fft,
                          hop_length=hop_length,
                          n_mels=n_mels)
        self.mel_spectrogram = MelSpectrogram(**mel_kwargs)
        self.amplitude_to_db = AmplitudeToDB(top_db=top_db)
Ejemplo n.º 10
0
 def __init__(self, sample_rate, n_fft, top_db, max_perc):
     """Create the STFT, time-stretch, mel and dB transforms and a uniform sampler.

     Args:
         sample_rate: sample rate forwarded to ``MelSpectrogram``.
         n_fft: FFT size shared by the STFT, time-stretch and mel transforms.
         top_db: dynamic-range cut for ``AmplitudeToDB``.
         max_perc: ``self.dist`` samples uniformly from
             ``[1 - max_perc, 1 + max_perc]``. Presumably these are
             time-stretch rates; TODO confirm in forward().
     """
     super().__init__()
     freq_bins = n_fft // 2 + 1  # number of one-sided STFT bins
     self.time_stretch = TimeStretch(hop_length=None, n_freq=freq_bins)
     # power=None keeps the complex STFT so it can be time-stretched.
     self.stft = Spectrogram(n_fft=n_fft, power=None)
     # NOTE(review): ComplexNorm is deprecated/removed in newer torchaudio
     # releases; confirm the pinned version.
     self.com_norm = ComplexNorm(power=2.)
     self.mel_specgram = MelSpectrogram(sample_rate, n_fft=n_fft, f_max=8000)
     self.AtoDB = AmplitudeToDB(top_db=top_db)
     low, high = 1. - max_perc, 1 + max_perc
     self.dist = Uniform(low, high)
Ejemplo n.º 11
0
    def __init__(self, norm_type, top_db=80.0):
        """Choose the normalisation stored in ``self._norm``.

        Args:
            norm_type: ``'db'`` converts power to decibels, ``'whiten'`` calls
                ``self.z_transform``; any other value leaves input unchanged.
            top_db: dynamic-range cut used by the ``'db'`` mode.
        """
        super(SpecNormalization, self).__init__()

        if norm_type == 'db':
            self._norm = AmplitudeToDB(stype='power', top_db=top_db)
        elif norm_type == 'whiten':
            # z_transform is looked up on each call (late binding).
            # NOTE(review): lambdas stored on the instance prevent pickling the
            # whole module (e.g. torch.save(model)).
            self._norm = lambda x: self.z_transform(x)
        else:
            self._norm = lambda x: x
Ejemplo n.º 12
0
 def create_spectro(self, item:AudioItem):
     """Turn an ``AudioItem`` into a spectrogram tensor according to ``self.config``.

     The steps are:
     - compute MFCCs, or a mel spectrogram (optionally converted to dB);
     - detach the result;
     - optionally standardise it;
     - optionally stack first- and second-order deltas.
     """
     cfg = self.config
     if cfg.mfcc:
         sg = cfg.sg_cfg
         spec = MFCC(sample_rate=item.sr, n_mfcc=sg.n_mfcc, melkwargs=sg.mel_args())(item.sig)
     else:
         sg = cfg.sg_cfg
         spec = MelSpectrogram(**(sg.mel_args()))(item.sig)
         if sg.to_db_scale:
             spec = AmplitudeToDB(top_db=sg.top_db)(spec)
     spec = spec.detach()
     if cfg.standardize:
         spec = standardize(spec)
     if cfg.delta:
         # For each slice along the leading dimension (presumably channels),
         # stack (x, delta, delta-delta), then concatenate along dim 0.
         spec = torch.cat([torch.stack([ch, torchdelta(ch), torchdelta(ch, order=2)]) for ch in spec])
     return spec
Ejemplo n.º 13
0
def spectrogram_from_audio(audio: Tensor, sample_rate: int, resample_rate: int,
                           mel_filters: int, seconds: int) -> Tensor:
    """Resample, down-mix and convert audio to a fixed-length dB mel spectrogram.

    Args:
        audio: waveform tensor with channels along dim 0.
        sample_rate: original sample rate of ``audio``.
        resample_rate: sample rate used for the mel transform.
        mel_filters: number of mel bands.
        seconds: target duration. The frame count is
            ``seconds * (resample_rate // hop_length)``.

    Returns:
        Tensor of shape ``(1, mel_filters, frames)``. Shorter results are
        padded on the right with ``pad``; longer ones are truncated.
    """
    resampler = Resample(orig_freq=sample_rate, new_freq=resample_rate)
    # Average all channels into one, keeping a size-1 leading dimension.
    mono = mean(resampler(audio), dim=0, keepdim=True)
    to_mel = MelSpectrogram(sample_rate=resample_rate, n_mels=mel_filters)
    log_mel = AmplitudeToDB()(to_mel(mono))
    current_frames = log_mel.shape[2]
    target_frames = seconds * (resample_rate // to_mel.hop_length)
    if current_frames >= target_frames:
        return log_mel[:, :, :target_frames]
    return pad(log_mel, (0, target_frames - current_frames))
Ejemplo n.º 14
0
    def __init__(
        self,
        rate: float,
        win_length: float,
        win_step: float,
        nmels: int,
        augment: bool,
        spectro_normalization: Tuple[float, float],
    ):
        """
        Build the waveform -> dB mel-spectrogram pipeline and, optionally,
        a SpecAugment-style masking pipeline.

        Args:
            rate: the sampling rate of the waveform
            win_length: the length in seconds of the window for the STFT
            win_step: the length in seconds of the step size of the STFT window
            nmels:  the number of mel scales to consider
            augment (bool) : whether to use data augmentation or not
            spectro_normalization: stored unchanged on ``self``. Presumably a
                (mean, std) pair used elsewhere to normalise the spectrogram;
                TODO confirm. Callers may pass ``None``.
        """
        # NOTE(review): super().__init__() is not called. If this class
        # derives from nn.Module, assigning nn.Sequential below would raise;
        # confirm the base class.
        # Window and hop sizes converted from seconds to samples.
        self.nfft = int(win_length * rate)
        self.nstep = int(win_step * rate)
        self.spectro_normalization = spectro_normalization

        ###########################
        #### START CODING HERE ####
        ###########################
        # NOTE(review): the template marker on the next line looks mangled by
        # an e-mail obfuscation filter ("[email protected]"). Restore it if a
        # tool parses these markers.
        # @[email protected]_tospectro = None
        # @SOL
        # Mel spectrogram followed by conversion to decibels.
        modules = [
            MelSpectrogram(sample_rate=rate,
                           n_fft=self.nfft,
                           hop_length=self.nstep,
                           n_mels=nmels),
            AmplitudeToDB(),
        ]
        self.transform_tospectro = nn.Sequential(*modules)
        # SOL@

        self.transform_augment = None
        if augment:
            time_mask_duration = 0.1  # s.
            # Despite the name, this is a width in STFT frames
            # (one frame per win_step seconds), not in audio samples.
            time_mask_nsamples = int(time_mask_duration / win_step)
            # Frequency masks span at most a quarter of the mel bands.
            nmel_mask = nmels // 4

            modules = [
                FrequencyMasking(nmel_mask),
                TimeMasking(time_mask_nsamples)
            ]
            self.transform_augment = nn.Sequential(*modules)
Ejemplo n.º 15
0
 def __init__(self,
              df,
              sound_dir,
              audio_sec=5,
              sample_rate=32000,
              n_fft=2**11,
              top_db=80
             ):
     """Store the dataset metadata and build the audio transforms.

     Args:
         df: metadata table, stored as ``self.train_df``.
         sound_dir: directory the audio is presumably loaded from.
         audio_sec: clip duration in seconds.
         sample_rate: sample rate for the mel transform and the target length.
         n_fft: FFT size for ``MelSpectrogram``.
         top_db: dynamic-range cut for ``AmplitudeToDB``.
     """
     # Bookkeeping.
     self.train_df, self.sound_dir = df, sound_dir
     self.audio_sec, self.sample_rate = audio_sec, sample_rate
     # Transforms.
     self.mel = MelSpectrogram(sample_rate, n_fft=n_fft)
     self.norm_db = AmplitudeToDB(top_db=top_db)
     self.time_strech = TimeStretch()
     # Clip length in samples. The attribute spelling is kept because callers
     # reference it.
     self.target_lenght = audio_sec * sample_rate
Ejemplo n.º 16
0
 def __init__(
     self,
     num_classes: int,
     hop_length: int,
     sample_rate: int,
     n_mels: int,
     n_fft: int,
     power: float,
     normalize: bool,
     use_decibels: bool,
 ) -> None:
     """Build the mel front end, four conv + batch-norm stages and the classifier.

     Submodule assignment order defines parameter / state_dict order, so the
     statements below are kept in their original sequence.

     Args:
         num_classes: size of the output layer (``self.logits``).
         hop_length: hop size forwarded to ``MelSpectrogram``.
         sample_rate: sample rate forwarded to ``MelSpectrogram``.
         n_mels: number of mel bands.
         n_fft: FFT size.
         power: spectrogram exponent forwarded as ``MelSpectrogram(power=...)``.
         normalize: forwarded as ``MelSpectrogram(normalized=...)``.
         use_decibels: stored on ``self``. Presumably decides whether
             ``self.amplitude2db`` is applied in forward(); TODO confirm.
     """
     super().__init__()
     self.use_decibels = use_decibels
     self.melspectrogram = MelSpectrogram(
         sample_rate=sample_rate,
         n_fft=n_fft,
         hop_length=hop_length,
         n_mels=n_mels,
         power=power,
         normalized=normalize,
     )
     # NOTE(review): AmplitudeToDB() defaults to stype="power" (10*log10)
     # regardless of ``power``; confirm this is intended when power != 2.
     self.amplitude2db = AmplitudeToDB()
     # Batch norm over the single input channel of the spectrogram image.
     self.input_bn = nn.BatchNorm2d(num_features=1)
     # Conv stack widens 1 -> 64 -> 128 -> 256 -> 512 channels, each conv
     # paired with a BatchNorm2d of matching width.
     self.conv1 = nn.Conv2d(in_channels=1,
                            out_channels=64,
                            kernel_size=[7, 3])
     self.bn1 = nn.BatchNorm2d(num_features=64)
     self.conv2 = nn.Conv2d(in_channels=64,
                            out_channels=128,
                            kernel_size=[1, 7])
     self.bn2 = nn.BatchNorm2d(num_features=128)
     self.conv3 = nn.Conv2d(in_channels=128,
                            out_channels=256,
                            kernel_size=[1, 10])
     self.bn3 = nn.BatchNorm2d(num_features=256)
     self.conv4 = nn.Conv2d(in_channels=256,
                            out_channels=512,
                            kernel_size=[7, 1])
     self.bn4 = nn.BatchNorm2d(num_features=512)
     # in_features=512 matches conv4's channels. Presumably the spatial
     # dims are pooled away in forward(); TODO confirm.
     self.logits = nn.Linear(in_features=512, out_features=num_classes)
Ejemplo n.º 17
0
 def __init__(self, sample_rate, n_fft, top_db, max_perc):
     """Create the spectral transforms, SpecAugment masks and speed-perturbation resamplers.

     Args:
         sample_rate: input sample rate in Hz; also the base for the resamplers.
         n_fft: FFT size shared by the STFT, time-stretch and mel transforms.
         top_db: dynamic-range cut for ``AmplitudeToDB``.
         max_perc: stored on ``self``; presumably the maximum time-stretch
             deviation used in forward(). TODO confirm.
     """
     super().__init__()
     self.time_stretch = TimeStretch(hop_length=None, n_freq=n_fft//2+1)
     # power=None keeps the complex STFT so it can be time-stretched.
     self.stft = Spectrogram(n_fft=n_fft, power=None)
     self.com_norm = ComplexNorm(power=2.)
     self.fm = FrequencyMasking(50)
     self.tm = TimeMasking(50)
     self.mel_specgram = MelSpectrogram(sample_rate, n_fft=n_fft, f_max=8000)
     self.AtoDB= AmplitudeToDB(top_db=top_db)
     self.max_perc = max_perc
     self.sample_rate = sample_rate
     # One resampler per speed factor (0.6x .. 1.4x).
     # Target rates are rounded to whole Hz because Resample truncates its
     # frequencies with int(). For example, 44100 * 0.7 evaluates to
     # 30869.999999999996 and would otherwise become 30869 Hz, giving a gcd of
     # 1 and an enormous resampling kernel.
     # NOTE(review): a plain list does not register these as submodules, so
     # .to(device) will not move their kernels; confirm they only run on CPU.
     self.resamples = [
         Resample(sample_rate, round(sample_rate * factor))
         for factor in (0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4)
     ]
Ejemplo n.º 18
0
 def create_spectro(self, item:AudioItem):
     """Turn an ``AudioItem`` into a spectrogram tensor according to ``self.config``.

     The spectrogram is chosen in this priority order:
     - MFCC;
     - a user-supplied ``custom_spectro`` callable;
     - a librosa mel spectrogram (when ``n_mels > 0``);
     - a plain torchaudio ``Spectrogram``.

     Non-MFCC results may then be converted to dB. Every result is detached,
     optionally standardised, and optionally extended with first- and
     second-order deltas.
     """
     if self.config.mfcc:
         mel = MFCC(sample_rate=item.sr, n_mfcc=self.config.sg_cfg.n_mfcc, melkwargs=self.config.sg_cfg.mel_args())(item.sig)
     else:
         c = self.config.sg_cfg
         if c.custom_spectro is not None:
             mel = c.custom_spectro(item.sig)
         elif c.n_mels > 0:
             # librosa needs a 1-D numpy array, so only the first row of
             # item.sig (presumably the first channel) is used.
             mel = librosa.feature.melspectrogram(y=np.array(item.sig[0,:]), sr=item.sr, fmax=c.f_max, fmin=c.f_min, **(c.mel_args()))
             mel = torch.from_numpy(mel)
             mel.unsqueeze_(0)  # restore a leading dimension of size 1
         else:
             mel = Spectrogram(**(c.spectro_args()))(item.sig)
         if c.to_db_scale:
             mel = AmplitudeToDB(top_db=c.top_db)(mel)
     mel = mel.detach()
     if self.config.standardize:
         mel = standardize(mel)
     if self.config.delta:
         # For each slice along dim 0, stack (x, delta, delta-delta), then
         # concatenate the stacks along dim 0.
         mel = torch.cat([torch.stack([m,torchdelta(m),torchdelta(m, order=2)]) for m in mel])
     return mel
Ejemplo n.º 19
0
 def __init__(self, **kwargs):
     """Create a mel spectrogram transform and a default dB converter.

     Args:
         **kwargs: forwarded unchanged to ``MelSpectrogram``.
     """
     mel_transform = MelSpectrogram(**kwargs)
     self.mel_spec = mel_transform
     self.db_scale = AmplitudeToDB()
Ejemplo n.º 20
0
 def __init__(self, sr: int, sg_cfg: SpectrogramConfig):
     """Build the STFT, mel-scale, MFCC and dB transforms from a spectrogram config.

     Args:
         sr: sample rate forwarded to ``MelScale`` and ``MFCC``.
         sg_cfg: config providing ``spec_args``, ``mel_args``, ``mfcc_args``
             and ``top_db``; also stored as ``self.sg_cfg``.
     """
     cfg = sg_cfg
     self.sg_cfg = cfg
     self.spec = Spectrogram(**cfg.spec_args)
     self.to_mel = MelScale(sample_rate=sr, **cfg.mel_args)
     self.mfcc = MFCC(sample_rate=sr, **cfg.mfcc_args)
     self.to_db = AmplitudeToDB(top_db=cfg.top_db)
Ejemplo n.º 21
0
    def amplitude_to_db(self, tensor):
        """Convert a magnitude (amplitude) tensor to decibels.

        A new ``AmplitudeToDB`` with ``stype='magnitude'`` and
        ``top_db=self.negative_cutoff_db`` is built on every call.
        NOTE(review): torchaudio expects a non-negative ``top_db``; confirm
        ``negative_cutoff_db`` is not stored as a negative number.
        """
        converter = AmplitudeToDB('magnitude', top_db=self.negative_cutoff_db)
        return converter(tensor)
Ejemplo n.º 22
0
    def __init__(
        self,
        num_classes: int,
        hop_length: int,
        sample_rate: int,
        n_mels: int,
        n_fft: int,
        power: float,
        normalize: bool,
        use_decibels: bool,
    ) -> None:
        """Build the mel front end, a residual conv stack and the classifier.

        Submodule assignment order defines parameter / state_dict order, so the
        statements below are kept in their original sequence.

        Args:
            num_classes: size of the output layer (``self.logits``).
            hop_length: hop size forwarded to ``MelSpectrogram``.
            sample_rate: sample rate forwarded to ``MelSpectrogram``.
            n_mels: number of mel bands.
            n_fft: FFT size.
            power: spectrogram exponent forwarded as ``MelSpectrogram(power=...)``.
            normalize: forwarded as ``MelSpectrogram(normalized=...)``.
            use_decibels: stored on ``self``. Presumably decides whether
                ``self.amplitude2db`` is applied in forward(); TODO confirm.
        """
        super().__init__()
        self.use_decibels = use_decibels

        self.melspectrogram = MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            power=power,
            normalized=normalize,
        )
        # NOTE(review): AmplitudeToDB() defaults to stype="power" regardless
        # of ``power``; confirm this is intended when power != 2.
        self.amplitude2db = AmplitudeToDB()
        # Batch norm over the single input channel of the spectrogram image.
        self.input_bn = nn.BatchNorm2d(num_features=1)

        # Stage widths 16 -> 32 -> 64 -> 128 -> 256. Each 3x3 conv uses
        # padding=1 and is followed (in definition order) by a BatchNorm2d
        # and, for stages 1-4, ResBlocks whose n_in == n_out.
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=16,
                               kernel_size=3,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.res1 = ResBlock(n_in=16, bottleneck=16, n_out=16)

        self.conv2 = nn.Conv2d(in_channels=16,
                               out_channels=32,
                               kernel_size=3,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.res2 = ResBlock(n_in=32, bottleneck=32, n_out=32)
        self.res3 = ResBlock(n_in=32, bottleneck=32, n_out=32)

        self.conv3 = nn.Conv2d(in_channels=32,
                               out_channels=64,
                               kernel_size=3,
                               padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=64)
        self.res4 = ResBlock(n_in=64, bottleneck=64, n_out=64)
        self.res5 = ResBlock(n_in=64, bottleneck=64, n_out=64)

        self.conv4 = nn.Conv2d(in_channels=64,
                               out_channels=128,
                               kernel_size=3,
                               padding=1)
        self.bn4 = nn.BatchNorm2d(num_features=128)
        self.res6 = ResBlock(n_in=128, bottleneck=128, n_out=128)
        self.res7 = ResBlock(n_in=128, bottleneck=128, n_out=128)

        self.conv5 = nn.Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=3,
                               padding=1)
        self.bn5 = nn.BatchNorm2d(num_features=256)

        # in_features=256 matches conv5's channels. Presumably the spatial
        # dims are pooled away in forward(); TODO confirm.
        self.logits = nn.Linear(in_features=256, out_features=num_classes)
Ejemplo n.º 23
0
 def __init__(self):
     """Create a power-to-decibel converter with an 80 dB dynamic-range cut."""
     self.amplitude_to_DB = AmplitudeToDB(stype='power', top_db=80)
Ejemplo n.º 24
0
from torch.nn import Sequential
from torch.nn import Module
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

from typing import Tuple

# Shared front end: 64-band mel spectrogram at 44.1 kHz
# (n_fft=2048, hop=512), converted to decibels. A single instance is shared.
commun_transforms = Sequential(
    MelSpectrogram(
        sample_rate=44100,
        n_fft=2048,
        hop_length=512,
        n_mels=64,
    ),
    AmplitudeToDB(),
)


def supervised() -> Tuple[Module, Module]:
    """Return the ``(train, val)`` transforms for supervised training.

    Both entries are the very same ``commun_transforms`` object, so training
    and validation are preprocessed identically.
    """
    shared = commun_transforms
    return shared, shared


def dct() -> Tuple[Module, Module]:
    """Return the ``(train, val)`` transforms for DCT (same as ``supervised()``)."""
    transforms = supervised()
    return transforms


def dct_uniloss() -> Tuple[Module, Module]:
    """Return the ``(train, val)`` transforms for DCT-uniloss (same as ``supervised()``)."""
    transforms = supervised()
    return transforms


def dct_aug4adv() -> Tuple[Module, Module]:
    """Transforms for DCT with adversarial augmentation; not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

Ejemplo n.º 25
0
def ex_waveform_spectro():
    """Plot a waveform excerpt next to its linear and mel spectrograms (in dB).

    Steps:
    - load example 10 of the CommonVoice training split;
    - keep the samples between 1 s and 3 s;
    - pass its mel spectrogram to ``plot_spectro``;
    - draw waveform, spectrogram and mel spectrogram side by side;
    - save the figure to ``waveform_to_spectro.png`` and show it.
    """
    dataset = load_dataset("train",
                           _DEFAULT_COMMONVOICE_ROOT,
                           _DEFAULT_COMMONVOICE_VERSION)

    # Take one of the waveforms
    idx = 10
    waveform, rate, _ = dataset[idx]
    n_begin = rate  # 1 s.
    n_end = 3*rate  # 3 s. (the excerpt is 2 s long)
    waveform = waveform[:, n_begin:n_end]  # B, T

    # Window and hop in samples, derived from the waveform's own rate. This
    # way the linear spectrogram uses the same analysis window as the
    # WaveformProcessor below, which is built with rate=rate.
    nfft = int(_DEFAULT_WIN_LENGTH * 1e-3 * rate)
    nstep = int(_DEFAULT_WIN_STEP * 1e-3 * rate)
    trans_spectro = nn.Sequential(
        Spectrogram(n_fft=nfft,
                    hop_length=nstep),
        AmplitudeToDB()
    )
    spectro = trans_spectro(waveform)  # B, n_freq, T

    trans_mel_spectro = WaveformProcessor(rate=rate,
                                          win_length=_DEFAULT_WIN_LENGTH*1e-3,
                                          win_step=_DEFAULT_WIN_STEP*1e-3,
                                          nmels=_DEFAULT_NUM_MELS,
                                          augment=False,
                                          spectro_normalization=None)
    mel_spectro = trans_mel_spectro(waveform.transpose(0, 1))  # T, B, n_mels
    plot_spectro(mel_spectro[:, 0, :], [],
                 _DEFAULT_WIN_STEP*1e-3,
                 CharMap())

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 3))

    ax = axes[0]
    ax.plot([i/rate for i in range(n_begin, n_end)], waveform[0])
    ax.set_xlabel('Time (s.)')
    ax.set_ylabel('Amplitude')
    ax.set_title('Waveform')

    ax = axes[1]
    im = ax.imshow(spectro[0],
                   extent=[n_begin/rate, n_end/rate,
                           0, spectro.shape[1]],
                   aspect='auto',
                   cmap='magma',
                   origin='lower')
    ax.set_ylabel('Frequency bins')
    ax.set_xlabel('Time (s.)')
    ax.set_title("Spectrogram (dB)")
    fig.colorbar(im, ax=ax)

    ax = axes[2]
    # The image is (n_mels, T), so the y-extent must span the n_mels bands
    # (last axis of mel_spectro), not the T frames on axis 0.
    im = ax.imshow(mel_spectro[:, 0, :].T,
                   extent=[n_begin/rate, n_end/rate,
                           0, mel_spectro.shape[-1]],
                   aspect='auto',
                   cmap='magma',
                   origin='lower')
    ax.set_ylabel('Mel scales')
    ax.set_xlabel('Time (s.)')
    ax.set_title("Mel-Spectrogram (dB)")
    fig.colorbar(im, ax=ax)

    plt.tight_layout()
    plt.savefig("waveform_to_spectro.png")
    plt.show()
Ejemplo n.º 26
0
    def __init__(self, train_loader, test_loader, valid_loader, general_args):
        """Set up data iterators, loss histories and the spectrogram converter.

        Args:
            train_loader: iterable of training batches.
            test_loader: iterable of test batches.
            valid_loader: iterable of validation batches.
            general_args: object providing ``train_batches_per_epoch``,
                ``test_batches_per_epoch`` and ``valid_batches_per_epoch``.
        """
        # Device
        self.device = ('cuda' if torch.cuda.is_available() else 'cpu')

        # Data generators
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Iterators to cycle over the datasets
        # NOTE(review): if ``cycle`` is itertools.cycle, it stores every
        # element of the first pass and replays those copies. The loaders are
        # therefore never re-iterated: there is no reshuffling or fresh
        # augmentation, and all batches stay in memory. Confirm this is
        # intended.
        self.train_loader_iter = cycle(iter(self.train_loader))
        self.valid_loader_iter = cycle(iter(self.valid_loader))
        self.test_loader_iter = cycle(iter(self.test_loader))

        # Epoch counter
        self.epoch = 0

        # Stored losses: one history list per loss term, per split.
        self.train_losses = {
            'time_l2': [],
            'freq_l2': [],
            'autoencoder_l2': [],
            'generator_adversarial': [],
            'discriminator_adversarial': {
                'real': [],
                'fake': []
            }
        }
        self.test_losses = {
            'time_l2': [],
            'freq_l2': [],
            'autoencoder_l2': [],
            'generator_adversarial': [],
            'discriminator_adversarial': {
                'real': [],
                'fake': []
            }
        }
        self.valid_losses = {
            'time_l2': [],
            'freq_l2': [],
            'autoencoder_l2': [],
            'generator_adversarial': [],
            'discriminator_adversarial': {
                'real': [],
                'fake': []
            }
        }

        # Time to frequency converter (STFT magnitude, then dB)
        self.spectrogram = Spectrogram(normalized=True,
                                       n_fft=512,
                                       hop_length=128).to(self.device)
        self.amplitude_to_db = AmplitudeToDB()

        # Boolean indicating if auto-encoder or generator
        self.is_autoencoder = False

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Set the pseudo-epochs (number of batches counted as one epoch)
        self.train_batches_per_epoch = general_args.train_batches_per_epoch
        self.test_batches_per_epoch = general_args.test_batches_per_epoch
        self.valid_batches_per_epoch = general_args.valid_batches_per_epoch