import numpy as np
import torch
import ST

# Image height and width in pixels; data arrays below are shaped (N_image, M, N).
M = N = 512
# Filter-bank parameters passed to ST.FiltersSet and ST.ST_2D.
# NOTE(review): presumably J = number of scales and L = number of orientations —
# confirm against the ST module.
J, L = 8, 4

# Output directory for the generated filter bank.
# NOTE(review): '#####' is a placeholder and must be replaced before running.
# The reload step concatenates this string directly with the file name, so it
# should end with a path separator (e.g. 'filters/').
save_dir = '#####'

# Build the Morlet filter bank for M x N images with parameters J and L, then
# generate it and save it to save_dir in single precision ('single').
filter_set = ST.FiltersSet(M, N, J, L)
filter_set.generate_morlet(
    save_dir=save_dir,
    if_save=True,
    precision='single',
)

# Reload the saved filter bank. The file name is assumed to match the one written
# by generate_morlet for these M, N, J, L and 'single' precision — TODO confirm in ST.
# allow_pickle is required because element [0] of the stored array is a mapping
# (indexed by 'filters_set'), i.e. a pickled Python object rather than plain numbers.
filters_set = np.load(
    f'{save_dir}filters_set_M{M}N{N}J{J}L{L}_single.npy',
    allow_pickle=True,
)[0]['filters_set']

# Create the 2-D ST calculator from the loaded filters.
# NOTE(review): device='gpu' — verify which device strings ST.ST_2D accepts
# (the comments further down mention "cuda() or cpu").
ST_calculator = ST.ST_2D(filters_set, J, L, device='gpu')

############ DEFINE DATA ARRAY #########
# Stack of 30 single-precision images shaped (N_image, M, N), the layout the ST
# calculator expects. np.zeros (not np.empty) so that any image not overwritten
# with real data is well-defined, instead of arbitrary leftover memory that could
# contain NaN/inf and make ST outputs nondeterministic.
data = np.zeros((30, M, N), dtype=np.float32)

################## ST ##################
# Input data should be a numpy array of images with shape (N_image, M, N).
# Outputs are torch tensors on the assigned computing device, e.g. cuda() or cpu.