コード例 #1
0
ファイル: utils.py プロジェクト: anatoole/WVD
def generate_sinc_filterbank(f0, f1, J, N):
    """Build a bank of learnable, Hanning-apodized sinc band-pass filters.

    ``get_scaled_freqs(f0, f1, J + 1)`` supplies ``J + 1`` frequencies;
    consecutive pairs are turned into ``(start, width)`` rows which are
    wrapped in a ``T.Variable`` named ``'c_freq'``.  The filters are sampled
    on ``N`` time points, windowed, and normalized to unit 2-norm along the
    time axis.

    Args:
        f0: first argument forwarded to ``get_scaled_freqs``.
        f1: second argument forwarded to ``get_scaled_freqs``.
        J: number of filters (``J + 1`` frequencies are requested).
        N: number of time samples per filter.

    Returns:
        A tuple ``(filters, freqs)``: the normalized filters after an
        ``expand_dims(..., 1)`` and an axis reversal ``[2, 1, 0]``
        (presumably shaped ``(J, 1, N)`` -- TODO confirm), and the
        ``T.Variable`` holding the ``(start, width)`` parametrization.
    """
    edges = get_scaled_freqs(f0, f1, J + 1)

    # Column 0: start of each band; column 1: distance to the next edge.
    band_table = np.stack([edges[:-1], edges[1:]], 1)
    band_table[:, 1] = band_table[:, 1] - band_table[:, 0]
    freqs = T.Variable(band_table, name='c_freq')

    # Absolute values keep both cut-offs ordered whatever the variable learns.
    low_cut = T.abs(freqs[:, 0])
    high_cut = low_cut + T.abs(freqs[:, 1])

    # Sample the band-pass responses on a column vector of time points.
    time_column = T.linspace(-N // 2, N // 2 - 1, N).reshape((-1, 1))
    band_pass = T.signal.sinc_bandpass(time_column, low_cut, high_cut)

    # Apodize with a Hanning window broadcast along the filter axis.
    window_column = T.signal.hanning(N).reshape((-1, 1))
    windowed = band_pass * window_column

    # Normalize each filter (2-norm over axis 0).
    scale = T.linalg.norm(windowed, 2, 0, keepdims=True)
    unit_filters = windowed / scale

    filters = T.transpose(T.expand_dims(unit_filters, 1), [2, 1, 0])
    return filters, freqs
コード例 #2
0
 def RNTK_function(self):
     """Assemble the symbolic RNTK outputs for a ``(self.N, self.length)`` input.

     A float32 placeholder is created, its first column is fed to
     ``self.RNTK_first``, the remaining columns are scanned through
     ``self.RNTK_middle``, and the scan result goes to ``self.RNTK_output``.

     Returns:
         The pair ``(RNTK_last, RNTK_avg)`` produced by ``self.RNTK_output``
         (symbolic outputs, not the compiled function).
     """
     print(f"N, {self.N}, length, {self.length}")
     inputs = T.Placeholder((self.N, self.length), 'float32')

     # State after the first column.
     rntk_start, gp_start = self.RNTK_first(inputs[:, 0])

     def advance(state, column):
         return self.RNTK_middle(state, column)

     # Scan over the remaining columns (transposed so the scan walks them).
     later_columns = T.transpose(inputs[:, 1:])
     scanned, _ = T.scan(
         advance,
         sequences=[later_columns],
         init=T.stack([rntk_start, gp_start]),
     )

     RNTK_last, RNTK_avg = self.RNTK_output(scanned)

     # NOTE(review): the compiled function is built but never returned or
     # stored; callers receive the symbolic outputs instead -- confirm intent.
     symjax.function(inputs, outputs=[RNTK_last, RNTK_avg])
     return RNTK_last, RNTK_avg
コード例 #3
0
def RNTK_function(N,length,param):
    """Compile a symjax function evaluating the RNTK outputs on ``(N, length)`` data.

    Args:
        N: first dimension of the float32 input placeholder.
        length: second dimension of the placeholder; also forwarded to
            ``RNTK_output``.
        param: mapping providing the keys ``'sigmaw'``, ``'sigmau'``,
            ``'sigmab'``, ``'sigmah'``, ``'L'``, ``'Lf'`` and ``'sigmav'``.

    Returns:
        The ``symjax.function`` mapping the placeholder to
        ``[RNTK_last, RNTK_avg]``.
    """
    inputs = T.Placeholder((N, length), 'float32')

    # Read the hyper-parameters once; every helper below reuses them.
    sigmaw = param['sigmaw']
    sigmau = param['sigmau']
    sigmab = param['sigmab']
    sigmah = param['sigmah']
    L = param['L']
    Lf = param['Lf']
    sigmav = param['sigmav']

    # State after the first column.
    rntk_start, gp_start = RNTK_first(
        inputs[:, 0], sigmaw, sigmau, sigmab, sigmah, L, Lf, sigmav)

    def advance(state, column):
        return RNTK_middle(state, column, sigmaw, sigmau, sigmab, L, Lf, sigmav)

    # Scan over the remaining columns.
    scanned, _ = T.scan(
        advance,
        sequences=[T.transpose(inputs[:, 1:])],
        init=T.stack([rntk_start, gp_start]),
    )

    RNTK_last, RNTK_avg = RNTK_output(scanned, sigmav, L, Lf, length)
    return symjax.function(inputs, outputs=[RNTK_last, RNTK_avg])
コード例 #4
0
    def compute_q(self, DATA):
        """Compute the sequence of ``q`` values by scanning over ``DATA`` transposed.

        The first entry is built from the first row of ``T.transpose(DATA)``;
        each later entry combines ``self.alg1_VT`` of the previous entry with
        the norm term of the current row.

        Args:
            DATA: symbolic array; it is transposed and scanned along the
                leading axis of the transpose.

        Returns:
            The initial value (with a new leading axis) concatenated with the
            per-step scan outputs.
        """
        rows = T.transpose(DATA)

        def magnitude(vec):
            # 2-norm over a freshly inserted leading axis of length one.
            return T.linalg.norm(T.expand_dims(vec, axis=0), ord=2, axis=0)

        # NOTE(review): ``self.su * 2`` (and ``self.sw*2`` below) multiply by
        # two, whereas ``sb`` and ``sh`` are squared in the same expressions --
        # confirm whether ``** 2`` was intended.  Behaviour is left unchanged.
        q_first = self.su * 2 * magnitude(rows[0]) + self.sb**2 + self.sh**2

        def advance(q_prev, row):
            # Per the original author's note, ``row`` should be a length-N vector.
            transformed = self.alg1_VT(q_prev)
            q_next = self.sw*2 * transformed + self.su * 2 * magnitude(row) + self.sb**2
            # The new value is both the carry and the emitted output.
            return q_next, q_next

        _, q_rest = T.scan(advance, init=q_first, sequences=[rows[1:]])
        return T.concatenate([T.expand_dims(q_first, axis=0), q_rest])
コード例 #5
0
    def create_func_for_diag(self):
        """Build the symbolic scan that accumulates the K / theta kernels.

        The method pairs every entry of ``self.reorganize_data()`` with its
        successor, seeds the carry with a boundary condition derived from the
        first column of ``self.DATA`` plus two ``self.compute_kernels`` calls,
        and then scans ``fn`` over
        ``jnp.arange(0, sum(self.dim_lengths) - self.dim_num)`` together with
        those index pairs.

        Returns:
            ``carry_ema[2:4]``, i.e. a slice of the first value returned by
            ``T.scan`` -- presumably the K and theta slots of the final carry
            (assumes ``create_T_list`` stacks its four inputs along a leading
            axis; TODO confirm).
        """

        NEW_DATA = self.reorganize_data()
        # Each scan step receives (current entry, next entry) of NEW_DATA.
        NEW_DATA_ATTACHED = jnp.array(list(zip(NEW_DATA[:-1], NEW_DATA[1:])))
        # print(NEW_DATA_ATTACHED)

        # Outer product of the first column of DATA with itself.
        x = self.DATA[:,0]
        X = x*x[:, None]

        boundary_condition = self.make_boundary_condition(X)
        
        # Create the initial kernels - should be starting with top right.
        # Two compute_kernels calls: the first starts from T.empty buffers at
        # (0, self.dim_1), the second is fed the first call's outputs at
        # (0, self.dim_1 - 1).
        temp_K, temp_theta = self.compute_kernels(self.qt, self.qtprime, 0, self.dim_1, boundary_condition, boundary_condition, T.empty((self.N, self.N)), T.empty((self.N, self.N)))
        init_K, init_theta = self.compute_kernels(self.qt, self.qtprime, 0, self.dim_1 - 1, boundary_condition, boundary_condition, temp_K, temp_theta)

        # Carry layout used throughout: [lambda, phi, K, theta].
        initial_conditions = create_T_list([boundary_condition, boundary_condition, init_K, init_theta])
        # initial_conditions = create_T_list([T.empty((self.N, self.N)), T.empty((self.N, self.N)), T.empty((self.N, self.N)), T.empty((self.N, self.N)) + 2])
        
        ## prev_vals - (4,self.N,self.N) - previous lambda, phi, and the two kernel values (K, theta), in that order
        ## idx - where we are on the diagonal
        ## data_idxs - [current pair, next pair] taken from NEW_DATA_ATTACHED;
        ##             element [0] of a pair indexes DATAPRIMEPH / qtprimeph,
        ##             element [1] indexes DATAPH / qtph
        def fn(prev_vals, idx, data_idxs, DATAPH, DATAPRIMEPH, qtph, qtprimeph):

            xTP = DATAPRIMEPH[data_idxs[0][0]] #N1
            xT = DATAPH[data_idxs[0][1]] #N2
            xINNER = T.inner(xT, xTP) #N1 x N2

            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            prev_K = prev_vals[2]
            prev_theta = prev_vals[3]
            ## not boundary condition
            # print(qtph[data_idxs[0][1] - 1])
            # print(qtprimeph[data_idxs[0][0] - 1])
            # alg2_VT is evaluated with the q values one position before the
            # current indices, together with the previous lambda.
            S, D = self.alg2_VT(qtph[data_idxs[0][1] - 1], qtprimeph[data_idxs[0][0] - 1] ,prev_lambda)
            new_lambda = self.sw ** 2 * S + self.su ** 2 * xINNER + self.sb ** 2 ## took out an X
            new_phi = new_lambda + self.sw ** 2 * prev_phi * D
            # new_phi = prev_phi
            # new_lambda = prev_lambda

            #compute kernels
            # Here alg2_VT uses the q values at the current indices and the
            # freshly computed lambda.
            S_kernel, D_kernel = self.alg2_VT(qtph[data_idxs[0][1]], qtprimeph[data_idxs[0][0]], new_lambda)
            new_K = prev_K + self.sv**2 * S_kernel#get current lambda, get current qtph and qtprimeph
            new_theta = prev_theta + self.sv**2 * S_kernel + self.sv**2 * D_kernel * new_phi #TODO
            # ret_K = prev_K
            # ret_theta = prev_theta

            equal_check = lambda e: T.equal(idx, e)

            # True when idx matches at least one entry of self.ends_of_calced_diags.
            equal_result = sum(np.vectorize(equal_check)(self.ends_of_calced_diags)) > 0

            # Branch taken when idx is one of self.ends_of_calced_diags:
            # lambda and phi are replaced by a boundary condition built from
            # the NEXT pair (di[1]), and that pair's kernel contribution is
            # added to the K / theta just computed.
            def true_f(k,t, qp, gph, dataph, dataprimeph, di): 
                xTP_NEXT = dataprimeph[di[1][0]]
                xT_NEXT = dataph[di[1][1]]
                xINNER_NEXT = T.inner(xT_NEXT, xTP_NEXT)
                new_bc = self.make_boundary_condition(xINNER_NEXT)
                ret_lambda = ret_phi = new_bc 

                S_bc_kernel, D_bc_kernel = self.alg2_VT(qp[di[1][1]], gph[di[1][0]], ret_lambda)
                ret_K = k + self.sv**2 * S_bc_kernel#get current lambda, get current qtph and qtprimeph
                ret_theta = t + self.sv**2 * S_bc_kernel + self.sv**2 * D_bc_kernel * ret_phi #TODO

                return ret_lambda, ret_phi, ret_K, ret_theta
            # Otherwise the freshly computed values pass through unchanged.
            false_f = lambda l,p,k,t: (l,p,k,t)

            # The two trailing lists are the argument lists for true_f and
            # false_f respectively (they match the two signatures above).
            ret_lambda, ret_phi, ret_K, ret_theta = T.cond(equal_result, true_f, false_f, [new_K, new_theta, qtph, qtprimeph, DATAPH, DATAPRIMEPH, data_idxs], [new_lambda, new_phi, new_K, new_theta])
            
            to_carry = create_T_list([ret_lambda, ret_phi, ret_K, ret_theta])
            # print('got poast second create list')
            
            # Only the carry matters; the per-step output is an empty array.
            return to_carry, np.array(())
        
        carry_ema, _ = T.scan(
            fn, init =  initial_conditions, sequences=[jnp.arange(0, sum(self.dim_lengths) - self.dim_num), NEW_DATA_ATTACHED], non_sequences=[T.transpose(self.DATA), T.transpose(self.DATAPRIME), self.qt, self.qtprime]
        )
        
        return carry_ema[2:4] ## so here, the output will be the added up kernels except for the boundary conditions