def test_unpad_signal(self):
        '''Compare ScatNet._unpad_signal against MATLAB reference results.

        For every reference file ``unpad_signal{data_len}{n_data}{res}{orig_len}{center}.mat``
        in TEST_DATA_FILEPATH, the parameters are parsed from the file name, the python
        implementation is run on the stored 'data_in' and the result is compared to 'data_out'.
        Also checks that the input array does not change upon function call.

        FIXME: add tests on python only (not comparing with matlab) to test sizes, other stuff
        '''
        # generate instance. for testing this function for this case, arguments do not matter
        scat = scu.ScatNet(2**6, 2**5)
        # calculate fields using python function using parameters retrieved from matlab test data file names
        matlab_fun = 'unpad_signal'
        test_files = glob.glob(TEST_DATA_FILEPATH + matlab_fun + '*.mat')
        if not test_files:
            # report a skip instead of passing without comparing anything
            self.skipTest('no MATLAB reference data found for ' + matlab_fun)
        # raw strings: '\{' is an invalid escape sequence in a normal string literal
        regex = matlab_fun + r'\{([0-9]+)\}' * 5 + r'\.mat'
        for test_file in test_files:
            match = re.search(regex, os.path.basename(test_file))
            if match is None:
                continue
            # groups 1 and 2 (data_len, n_data) are not needed as function arguments
            res = int(match.group(3))
            orig_len = int(match.group(4))
            center = bool(int(match.group(5)))

            # open read-only and close the file as soon as the datasets are loaded
            with h5py.File(test_file, 'r') as ref_results_file:
                data_in_ref = np.array(ref_results_file['data_in'])
                data_out_ref = np.array(ref_results_file['data_out'])
            data_in_ref_orig = np.copy(data_in_ref)

            data_out = scat._unpad_signal(data_in_ref, res=res, orig_len=orig_len, center=center)
            self.assertTrue(np.allclose(data_out, data_out_ref, rtol=1e-5, atol=1e-8))
            # check if input array does not change upon function call
            self.assertTrue(np.allclose(data_in_ref, data_in_ref_orig, rtol=1e-5, atol=1e-8))
    def test_morletify(self):
        '''Compare ScatNet._morletify against MATLAB reference results.

        Reference files are named ``morletify{N}{xi}{sigma}{psi_sigma}.mat``. Only psi_sigma
        is passed to the python implementation; the other parameters are parsed by the regex
        but not used here. Also checks that the input array does not change upon function call.

        FIXME: add tests on python only (not comparing with matlab) to test sizes, other stuff
        '''
        # generate instance. for testing this function, arguments do not matter
        scat = scu.ScatNet(2**6, 2**5)
        # calculate fields using python function using parameters retrieved from matlab test data file names
        matlab_fun = 'morletify'
        test_files = glob.glob(TEST_DATA_FILEPATH + matlab_fun + '*.mat')
        if not test_files:
            # report a skip instead of passing without comparing anything
            self.skipTest('no MATLAB reference data found for ' + matlab_fun)
        # raw strings: '\{' is an invalid escape sequence in a normal string literal
        regex = matlab_fun + r'\{([0-9]+)\}' + r'\{(-?[0-9]+\.?[0-9]*)\}' * 3 + r'\.mat'
        for test_file in test_files:
            match = re.search(regex, os.path.basename(test_file))
            if match is None:
                continue
            # groups 1-3 (N, xi, sigma) are not arguments of the python function
            psi_sigma = float(match.group(4))

            # open read-only and close the file as soon as the datasets are loaded
            with h5py.File(test_file, 'r') as ref_results_file:
                # the following arrays have shape (1, len); keep the single row
                data_in_ref = np.array(ref_results_file['data_in'][0])
                data_out_ref = np.array(ref_results_file['data_out'][0])
            data_in_ref_orig = np.copy(data_in_ref)
            data_out = scat._morletify(data_in_ref, psi_sigma)

            self.assertTrue(np.allclose(data_out, data_out_ref, rtol=1e-5, atol=1e-8))
            # check if input array does not change upon function call
            self.assertTrue(np.allclose(data_in_ref_orig, data_in_ref, rtol=1e-5, atol=1e-8))
    def test_conv_sub_1d_fourier_truncated(self):
        '''Compare ScatNet._conv_sub_1d (truncated Fourier filter) against MATLAB reference results.

        Reference files are named
        ``conv_sub_1d{data_len}{n_data}{fourier_truncated}{filt_len}{ds}{filter_len}{start}{thresh}.mat``.
        The truncated filter dict is rebuilt from the stored coefficients and the file-name parameters.
        NOTE: when reading arrays with complex numbers from matlab, the format for each number becomes a tuple.
        to avoid this, the real and imaginary parts are saved separately and compared separately.
        Also checks that neither the data nor the filter argument changes upon function call.

        FIXME: add tests on python only (not comparing with matlab) to test sizes, other stuff
        '''
        # generate instance. for testing this function for this case, arguments do not matter except that
        # filter_format should be fourier_truncated
        scat = scu.ScatNet(2**6, 2**5, filter_format='fourier_truncated')
        # calculate fields using python function using parameters retrieved from matlab test data file names
        matlab_fun = 'conv_sub_1d'
        test_files = glob.glob(TEST_DATA_FILEPATH + matlab_fun + '*.mat')
        if not test_files:
            # report a skip instead of passing without comparing anything
            self.skipTest('no MATLAB reference data found for ' + matlab_fun)
        # filt argument given as a truncated filter dict; raw strings avoid invalid '\{' escapes
        regex = (matlab_fun + r'\{([0-9]+)\}' * 2 + r'\{fourier_truncated\}' + r'\{([0-9]+)\}' * 3
            + r'\{(-?[0-9]+)\}' + r'\{([0-9]+\.?[0-9]*)\}' + r'\.mat')
        for test_file in test_files:
            match = re.search(regex, os.path.basename(test_file))
            if match is None:
                continue

            # groups 1-3 (data_len, n_data, filt_len) and 7 (thresh) are not needed here
            ds = int(match.group(4))
            filter_len = int(match.group(5))
            start = int(match.group(6))

            # open read-only and close the file as soon as the datasets are loaded
            with h5py.File(test_file, 'r') as ref_results_file:
                data_in_ref_real = np.array(ref_results_file['data_in_real'])
                data_in_ref_imag = np.array(ref_results_file['data_in_imag'])
                # coef_real and coef_imag have shapes (1, filt_len)
                coef_in_ref_real = np.array(ref_results_file['coef_real'])[0]
                coef_in_ref_imag = np.array(ref_results_file['coef_imag'])[0]
                data_out_ref_real = np.array(ref_results_file['data_out_real'])
                data_out_ref_imag = np.array(ref_results_file['data_out_imag'])

            data_in_ref = data_in_ref_real + data_in_ref_imag * 1j
            data_in_ref_orig = np.copy(data_in_ref)
            coef_in_ref = coef_in_ref_real + coef_in_ref_imag * 1j

            # reconstruct filt; MATLAB's 1-based start index becomes 0-based
            filt = {'filter_len':filter_len, 'start':start - 1, 'coef':coef_in_ref}
            filt_orig = copy.deepcopy(filt)

            data_out = scat._conv_sub_1d(data_in_ref, filt, ds)
            data_out_real = np.real(data_out)
            data_out_imag = np.imag(data_out)

            self.assertTrue(np.allclose(data_out_real, data_out_ref_real, rtol=1e-5, atol=1e-8))
            self.assertTrue(np.allclose(data_out_imag, data_out_ref_imag, rtol=1e-5, atol=1e-8))
            # check if input arguments do not change upon function call
            self.assertTrue(np.allclose(data_in_ref_orig, data_in_ref, rtol=1e-5, atol=1e-8))
            for key in ['filter_len', 'start']:
                self.assertEqual(filt[key], filt_orig[key])
            self.assertTrue(np.allclose(filt['coef'], filt_orig['coef'], rtol=1e-5, atol=1e-8))
    def test_conv_sub_1d_filt_array(self):
        '''Compare ScatNet._conv_sub_1d (full Fourier filter array) against MATLAB reference results.

        Reference files are named ``conv_sub_1d{data_len}{n_data}{filt_len}{ds}.mat``.
        NOTE: when reading arrays with complex numbers from matlab, the format for each number becomes a tuple.
        to avoid this, the real and imaginary parts are saved separately and compared separately.
        Also checks that neither the data nor the filter argument changes upon function call.

        FIXME: add tests on python only (not comparing with matlab) to test sizes, other stuff
        '''
        # generate instance. for testing this function for this case, arguments do not matter except that
        # filter_format should be fourier
        scat = scu.ScatNet(2**6, 2**5, filter_format='fourier')

        # calculate fields using python function using parameters retrieved from matlab test data file names
        matlab_fun = 'conv_sub_1d'
        test_files = glob.glob(TEST_DATA_FILEPATH + matlab_fun + '*.mat')
        if not test_files:
            # report a skip instead of passing without comparing anything
            self.skipTest('no MATLAB reference data found for ' + matlab_fun)
        # filt argument given as nparray; raw strings avoid invalid '\{' escapes
        regex = matlab_fun + r'\{([0-9]+)\}' * 4 + r'\.mat'
        for test_file in test_files:
            match = re.search(regex, os.path.basename(test_file))
            if match is None:
                continue
            # groups 1-3 (data_len, n_data, filt_len) are not needed here
            ds = int(match.group(4))

            # open read-only and close the file as soon as the datasets are loaded
            with h5py.File(test_file, 'r') as ref_results_file:
                data_in_ref_real = np.array(ref_results_file['data_in_real'])
                data_in_ref_imag = np.array(ref_results_file['data_in_imag'])
                # filt_in_real / filt_in_imag have size (1, filt_len)
                filt_in_ref_real = np.array(ref_results_file['filt_in_real'])[0]
                filt_in_ref_imag = np.array(ref_results_file['filt_in_imag'])[0]
                data_out_ref_real = np.array(ref_results_file['data_out_real'])
                data_out_ref_imag = np.array(ref_results_file['data_out_imag'])

            data_in_ref = data_in_ref_real + data_in_ref_imag * 1j
            filt_in_ref = filt_in_ref_real + filt_in_ref_imag * 1j
            data_in_ref_orig = np.copy(data_in_ref)
            filt_in_ref_orig = np.copy(filt_in_ref)
            data_out = scat._conv_sub_1d(data_in_ref, filt_in_ref, ds)
            data_out_real = np.real(data_out)
            data_out_imag = np.imag(data_out)

            self.assertTrue(np.allclose(data_out_real, data_out_ref_real, rtol=1e-5, atol=1e-8))
            self.assertTrue(np.allclose(data_out_imag, data_out_ref_imag, rtol=1e-5, atol=1e-8))
            # check if input arrays do not change upon function call
            self.assertTrue(np.allclose(filt_in_ref, filt_in_ref_orig, rtol=1e-5, atol=1e-8))
            self.assertTrue(np.allclose(data_in_ref, data_in_ref_orig, rtol=1e-5, atol=1e-8))
    def test_truncate_filter(self):
        '''Compare ScatNet._truncate_filter against MATLAB reference results.

        Reference files are named ``truncate_filter{filt_len}{thresh}.mat``. The truncated
        coefficients, the start index (MATLAB 1-based) and the filter length 'N' are compared.
        NOTE: when reading arrays with complex numbers from matlab, the format for each number becomes a tuple.
        to avoid this, the real and imaginary parts are saved separately and compared separately.
        Also checks that the input array does not change upon function call.

        FIXME: add tests on python only (not comparing with matlab) to test sizes, other stuff
        '''
        # generate instance. for testing this function, arguments do not matter
        scat = scu.ScatNet(2**6, 2**5)
        # calculate fields using python function using parameters retrieved from matlab test data file names
        matlab_fun = 'truncate_filter'
        test_files = glob.glob(TEST_DATA_FILEPATH + matlab_fun + '*.mat')
        if not test_files:
            # report a skip instead of passing without comparing anything
            self.skipTest('no MATLAB reference data found for ' + matlab_fun)
        # raw strings avoid invalid '\{' escapes
        regex = matlab_fun + r'\{([0-9]+)\}' + r'\{([0-9]+\.?[0-9]*)\}' + r'\.mat'
        for test_file in test_files:
            match = re.search(regex, os.path.basename(test_file))
            if match is None:
                continue

            # group 1 (filt_len) is not needed as a function argument
            thresh = float(match.group(2))

            # open read-only; datasets must be materialized as arrays before the file is closed
            with h5py.File(test_file, 'r') as ref_results_file:
                start_ref = np.array(ref_results_file['start'])
                N_ref = np.array(ref_results_file['N'])
                # the following arrays have shapes (1, coef_len)
                coef_in_ref_real = np.array(ref_results_file['coef_in_real'])[0]
                coef_in_ref_imag = np.array(ref_results_file['coef_in_imag'])[0]
                coef_out_ref_real = np.array(ref_results_file['coef_out_real'])[0]
                coef_out_ref_imag = np.array(ref_results_file['coef_out_imag'])[0]

            coef_in_ref = coef_in_ref_real + coef_in_ref_imag * 1j
            coef_in_ref_orig = np.copy(coef_in_ref)

            filt = scat._truncate_filter(coef_in_ref, thresh)

            self.assertTrue(np.allclose(np.real(filt['coef']), coef_out_ref_real, rtol=1e-5, atol=1e-8))
            self.assertTrue(np.allclose(np.imag(filt['coef']), coef_out_ref_imag, rtol=1e-5, atol=1e-8))
            # python start index is 0-based, MATLAB's is 1-based
            self.assertEqual(filt['start'] + 1, int(start_ref[0,0]))
            self.assertEqual(filt['filter_len'], N_ref[0,0])
            # check if input array does not change upon function call
            self.assertTrue(np.allclose(coef_in_ref_orig, coef_in_ref, rtol=1e-5, atol=1e-8))
    def test_periodize_filter(self):
        '''Compare ScatNet._periodize_filter against MATLAB reference results.

        For every ``periodize_filter*.mat`` file, the stored 'filter_f' is periodized and the
        concatenation of the resulting 'coef' list is compared to the stored 'coef_concat'.
        Also checks that the input array does not change upon function call.

        FIXME: add tests on python only (not comparing with matlab) to test sizes, other stuff
        '''
        # generate instance. for testing this function for this case, arguments do not matter
        scat = scu.ScatNet(2**6, 2**5)
        # calculate fields using python function using reference data from matlab test data files
        matlab_fun = 'periodize_filter'
        test_files = glob.glob(TEST_DATA_FILEPATH + matlab_fun + '*.mat')
        if not test_files:
            # report a skip instead of passing without comparing anything
            self.skipTest('no MATLAB reference data found for ' + matlab_fun)
        for test_file in test_files:
            # open read-only and close the file as soon as the datasets are loaded
            with h5py.File(test_file, 'r') as ref_results_file:
                # the values of ref_results_file have both shape (1, filter_len)
                filter_in_ref = np.array(ref_results_file['filter_f'])[0]
                coef_out_ref = np.array(ref_results_file['coef_concat'])[0]
            filter_in_ref_orig = np.copy(filter_in_ref)

            coef_out = np.concatenate(scat._periodize_filter(filter_in_ref)['coef'], axis=0)
            # coef_out has shape (filter_len,)
            self.assertTrue(np.allclose(coef_out, coef_out_ref, rtol=1e-5, atol=1e-8))
            # check if input array does not change upon function call
            self.assertTrue(np.allclose(filter_in_ref, filter_in_ref_orig, rtol=1e-5, atol=1e-8))
from mpl_toolkits.mplot3d import Axes3D

# Script setup: plotting style, ScatNet parameters, and the transform of the simulated
# process selected by `sim_type`. Only the 'brw' (Brownian) branch is visible here.
plt.style.use('dark_background')
fontsize_title = 18
fontsize_label = 14

# scatnet parameters
data_len = 2**11
avg_len = 2**8
n_data = 200
dt = 0.001
# number of filters per octave; presumably one entry per scattering layer -- TODO confirm
n_filter_octave = [1,1]
# selects which simulated process is transformed; NOTE(review): 'tbd' has no branch in this view
sim_type = 'tbd' # 'brw', 'psn', 'obd', 'tbd'

n_decim = 3 # precision for external parameters (decimal places used when building label strings)
scat = scu.ScatNet(data_len, avg_len, n_filter_octave=n_filter_octave)

# simulate brownian
if sim_type == 'brw':
    # diffusion coefficients 4, 5, 6, 7
    diff_coefs_brw = np.arange(4,8,1)
    samples = siu.sim_brownian(data_len, diff_coefs_brw, dt=dt, n_data=n_data)
    traj_brw = samples['data']
    # NOTE(review): scat_brw is built with the same arguments as `scat`, but the transform below
    # uses `scat`; scat_brw looks redundant unless it is used later in the file -- confirm
    scat_brw = scu.ScatNet(data_len, avg_len, n_filter_octave=n_filter_octave)
    S_brw = scat.transform(traj_brw)
    # log of the scattering coefficients, then stacked into a single array by scu.stack_scat
    S_brw_log = scu.log_scat(S_brw)
    S_brw_log = scu.stack_scat(S_brw_log)
    S_brw_log_mean = S_brw_log.mean(axis=-1) # average along time axis
    # keep the first axis and flatten all remaining axes into one feature axis
    S_brw_log_mean = np.reshape(S_brw_log_mean, (S_brw_log_mean.shape[0], -1))

    # one string label per diffusion coefficient, repeated n_data times:
    # labels_brw has shape (len(diff_coefs_brw), n_data)
    diff_coefs_brw_str = np.round(diff_coefs_brw, n_decim).astype(str)
    labels_brw = np.repeat(diff_coefs_brw_str[:, np.newaxis], n_data, axis=-1)
    def test_morlet_freq_1d(self):
        '''
        - test if xi_psi, bw_psi are both numpy arrays.
        - test if bw_phi and all elements in bw_psi are positive.
        NOTE: xi_psi can have negative elements. The stepsize in xi_psi in the linearly spaced
        spectrum can in theory be negative if sigma_phi is small although for parameters that
        construct phi this does not happen. However, even though the stepsize is positive for normal
        inputs, the number of steps taken linearly towards the negative frequency regime can result
        in negative values of center frequencies

        REVIEW: confirm whether the parameters that result in negative center frequencies are feasible
        If not, no need to test for cases having negative center frequencies

        - test if xi_psi, bw_psi have length n_filter_log + n_filter_lin, n_filter_log + n_filter_lin + 1, respectively
        - test if filt_opt does not change upon function call
        - compare against MATLAB reference files
          morlet_freq_1d{n_filter_octave}{n_filter_log}{n_filter_lin}{xi_psi}{sigma_psi}{sigma_phi}{phi_dirac}.mat
        - test that mutating the options does not change the outputs, and vice versa
        FIXME: add test case where output lists have length 0(?) or 1
        '''
        # generate instance. for testing this function for this case, arguments do not matter
        scat = scu.ScatNet(2**6, 2**5)
        filt_opt = {'xi_psi':0.5, 'sigma_psi':0.4, 'sigma_phi':0.5, 'n_filter_log':11,
            'n_filter_octave':8, 'n_filter_lin':5}
        # retain a copy of filt_opt to confirm no change upon function call (values are scalars)
        filt_opt_cp = filt_opt.copy()
        xi_psi, bw_psi, bw_phi = scat._morlet_freq_1d(filt_opt)
        self.assertIsInstance(xi_psi, np.ndarray)
        self.assertIsInstance(bw_psi, np.ndarray)
        # xi_psi may legitimately contain negative values, see NOTE above
        self.assertTrue(all(bw > 0 for bw in bw_psi))
        self.assertTrue(bw_phi > 0)
        self.assertEqual(len(xi_psi), filt_opt['n_filter_log'] + filt_opt['n_filter_lin'])
        self.assertEqual(len(bw_psi), filt_opt['n_filter_log'] + filt_opt['n_filter_lin'] + 1)
        self.assertEqual(filt_opt, filt_opt_cp)

        # calculate fields using python function using parameters retrieved from matlab test data file names
        matlab_fun = 'morlet_freq_1d'
        test_files = glob.glob(TEST_DATA_FILEPATH + matlab_fun + '*.mat')
        if not test_files:
            # python-only checks above still ran; report the MATLAB comparison as skipped
            self.skipTest('no MATLAB reference data found for ' + matlab_fun)
        # raw strings avoid invalid '\{' escapes
        regex = matlab_fun + r'\{([0-9]+\.?[0-9]*)\}' * 7 + r'\.mat'
        for test_file in test_files:
            match = re.search(regex, os.path.basename(test_file))
            if match is None:
                continue
            n_filter_octave = int(match.group(1))
            n_filter_log = int(match.group(2))
            n_filter_lin = int(match.group(3))
            # named *_in so it is not confused with the xi_psi output below
            xi_psi_in = float(match.group(4))
            sigma_psi = float(match.group(5))
            sigma_phi = float(match.group(6))
            # group 7 (phi_dirac) is not used: in the python version phi_dirac is always assumed
            # to be false and therefore only phi_dirac = false test cases are generated in the
            # function that creates test data

            options = {'n_filter_octave':n_filter_octave, 'n_filter_log':n_filter_log, 'n_filter_lin':n_filter_lin,
                'xi_psi':xi_psi_in, 'sigma_psi':sigma_psi, 'sigma_phi':sigma_phi}
            options_orig = copy.deepcopy(options)
            xi_psi, bw_psi, bw_phi = scat._morlet_freq_1d(options)

            xi_psi_orig = copy.deepcopy(xi_psi)
            bw_psi_orig = copy.deepcopy(bw_psi)

            xi_psi_arr = np.array(xi_psi)
            bw_psi_arr = np.array(bw_psi)
            bw_phi = np.array(bw_phi)

            # open read-only and close the file as soon as the datasets are loaded
            with h5py.File(test_file, 'r') as ref_results_file:
                xi_psi_ref = np.array(ref_results_file['xi_psi']).squeeze(axis=1)
                bw_psi_ref = np.array(ref_results_file['bw_psi']).squeeze(axis=1)
                bw_phi_ref = np.array(ref_results_file['bw_phi']).squeeze(axis=1)[0]

            # xi_psi_arr, xi_psi_ref, bw_psi_arr, bw_psi_ref are all rank 1 arrays. bw_phi, bw_phi_ref are both float type scalars
            self.assertTrue(np.allclose(xi_psi_arr, xi_psi_ref, rtol=1e-5, atol=1e-8))
            self.assertTrue(np.allclose(bw_psi_arr, bw_psi_ref, rtol=1e-5, atol=1e-8))
            self.assertTrue(np.allclose(bw_phi, bw_phi_ref, rtol=1e-5, atol=1e-8))
            # check if input options do not change upon function call
            self.assertEqual(options, options_orig)

            # check if output does not change upon changing input argument
            options.clear()
            self.assertTrue(np.allclose(xi_psi, xi_psi_orig, rtol=1e-5, atol=1e-8))
            self.assertTrue(np.allclose(bw_psi, bw_psi_orig, rtol=1e-5, atol=1e-8))

            # check if input argument does not change upon changing output
            # need to run function again
            options = copy.deepcopy(options_orig)
            xi_psi, bw_psi, bw_phi = scat._morlet_freq_1d(options)
            xi_psi += 1
            bw_psi += 1
            self.assertEqual(options, options_orig)
# --- dilation experiment: one random walk resampled at several time-dilation factors ---
fontsize_label = 14
data_len, avg_len, disp_len = 2**12, 2**8, 2**7
dt = 0.001
n_data = 10
eps = 0.02  # relative dilation step between consecutive series

# dilation factor of each series: 1, 1 + eps, ..., 1 + eps * (n_data - 1)
dilation = 1 + eps * np.arange(n_data)

# random walk on a longer (padded) time grid, wrapped in an interpolant for resampling
padded_len = np.ceil(data_len * (1 + eps * n_data))
t = dt * np.arange(0, data_len)
t_padded = dt * np.arange(0, padded_len)
#a_padded = np.sin(10 * t_padded) + np.sin(30 * t_padded) + np.sin(50 * t_padded)
a_padded = np.cumsum(np.random.randn(t_padded.size))
f = interpolate.interp1d(t_padded, a_padded)

scat = scu.ScatNet(data_len=data_len, avg_len=avg_len)

# evaluate the interpolant at t / factor for every dilation factor
data = np.stack([f(t / factor) for factor in dilation], axis=0)  # shaped (n_data, data_len)

fig1, ax1 = plt.subplots()

# plot the first disp_len samples of every series, one line per series
ax1.plot(data[:, :disp_len].T)
ax1.set_title('Dilated time series', fontsize=fontsize_title)
ax1.set_xlabel('Time', fontsize=fontsize_label)
ax1.set_ylabel('Position', fontsize=fontsize_label)

# insert a singleton axis so the transform input is (n_data, 1, data_len), then stack the result
S = scu.stack_scat(scat.transform(np.expand_dims(data, axis=1)))  # (n_data, 1, n_nodes, data_scat_len)