plt.ylim(-1, 1)
        plt.xlabel('loss evaluation index')


        plt.show()


    #%%
if __name__ == "__main__":
    # Script-only section: the imports below are executed only when this file
    # is run directly, not when it is imported as a module.
    from preprocess import Datasets
    import torch.utils.data
    from preprocess.transforms import SpectrogramTransform
    from preprocess.fusion import separate_sensors_collate
    from param import fs, duration_window, duration_overlap, spectro_batch_size

    # Spectrogram transform over the "Acc_norm" and "Gyr_y" signals, using the
    # window/overlap/batch settings imported from `param`.
    spectrogram_transform = SpectrogramTransform(
        ["Acc_norm", "Gyr_y"],
        fs,
        duration_window,
        duration_overlap,
        spectro_batch_size,
        interpolation='linear',
        log_power=True,
        out_size=(48, 48),
    )
    collate_fn = separate_sensors_collate

    # Do not reload the datasets if they already exist (e.g. when this cell is
    # re-run in an interactive session). If either one is missing, both are
    # (re)built.
    if "train_dataset" not in globals() or "val_dataset" not in globals():
        train_dataset = Datasets.SignalsDataSet(
            mode='train',
            split='balanced',
            comp_preprocess_first=True,
            transform=spectrogram_transform,
        )
        val_dataset = Datasets.SignalsDataSet(
            mode='val',
            split='balanced',
            comp_preprocess_first=True,
            transform=spectrogram_transform,
        )

    # Only the training loader shuffles; both use the same batch size and
    # collate function and load in the main process (num_workers=0).
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        collate_fn=collate_fn,
        num_workers=0,
        shuffle=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=64,
        collate_fn=collate_fn,
        num_workers=0,
    )

    # NOTE(review): `dataloader`, `DS`, `example_signals`, `ConcatCollate` and
    # `plt` are used below but are not defined in the nearby lines -- verify
    # they are in scope at this point of the script.

    # Not read anywhere in the lines that follow -- presumably filled in later.
    str_result = ""
    # ---- concatenation along the 'freq' axis ----
    # Swap the loader's collate function for a ConcatCollate configured with
    # axis_collate='freq', then draw a single batch with it.
    collate_concat = ConcatCollate(axis_collate='freq')
    dataloader.collate_fn = collate_concat
    X_batch, _ = next(iter(dataloader))

    # Take element [0, 0, :] of the batch for display
    # (presumably first sample, first channel -- TODO confirm axis order).
    signal = X_batch[0, 0, :]
    # Third panel of a 3-row, 1-column grid; the other panels are presumably
    # drawn earlier in the script.
    plt.subplot(3, 1, 3)
    # Move the tensor to the CPU before converting it to NumPy for plotting.
    plt.plot(signal.to(torch.device('cpu')).numpy())
    plt.title("'frequency' concat")

    # ---------------------- spectrogram ----------------------------
    # Build a spectrogram transform for `example_signals` (no interpolation,
    # log power) and install it on `DS`.
    # NOTE(review): `DS` is presumably the dataset behind `dataloader`, so the
    # next batch would go through this transform -- confirm.
    spectrogram_transform = SpectrogramTransform(example_signals,
                                                 fs,
                                                 duration_window,
                                                 duration_overlap,
                                                 spectro_batch_size,
                                                 interpolation='none',
                                                 log_power=True)
    DS.transform = spectrogram_transform

    # ---- concatenation along the 'freq' axis, in a new figure ----
    plt.figure()
    collate_concat = ConcatCollate(axis_collate='freq')
    dataloader.collate_fn = collate_concat

    # Draw one batch with the new transform/collate combination.
    X_batch, _ = next(iter(dataloader))

    # Element [0, 0, :, :] of the batch
    # (presumably a 2-D spectrogram of the first sample -- TODO confirm).
    signal = X_batch[0, 0, :, :]

    # First panel of a 1-row, 2-column grid.
    plt.subplot(1, 2, 1)