Example #1
    def compute_kernels(self, final_ema):

        diag_ends = self.get_ends_of_diags(final_ema)

        S_init, D_init = self.VT(diag_ends[0][0])
        init_Kappa = self.sv**2 * S_init
        init_Theta = init_Kappa + self.sv**2 * diag_ends[0][1] * D_init
        init_list = T.concatenate([
            T.expand_dims(init_Kappa, axis=0),
            T.expand_dims(init_Theta, axis=0)
        ])

        def map_test(gp_rntk_sum, gp_rntk):
            S, D = self.VT(gp_rntk[0])
            ret1 = self.sv**2 * S
            ret2 = ret1 + self.sv**2 * gp_rntk[1] * D
            gp_rntk_sum = T.index_add(gp_rntk_sum, 0, ret1)
            gp_rntk_sum = T.index_add(gp_rntk_sum, 1, ret2)

            return gp_rntk_sum, gp_rntk_sum

        final_K_T, inter_results = T.scan(map_test,
                                          init=init_list,
                                          sequences=[diag_ends[1:]])
        return final_K_T
Example #2
    def create_func_for_diag(self,
                             dim1idx,
                             dim2idx,
                             function=False,
                             jmode=False):
        diag = self.make_inputs(dim1idx, dim2idx, jmode=jmode)
        # print('test')

        ## prev_vals - (2,1) - previous phi and lambda values
        ## idx - where we are on the diagonal
        ## d1idx - y value of first dimension diag start
        ## d2idx - x value of second dimension diag start
        ## d1ph - max value of first dimension
        ## d2ph - max value of second dimension
        bc = self.sh**2 * self.sw**2 * T.eye(
            self.n, self.n) + (self.su**2) * self.X + self.sb**2
        single_boundary_condition = T.expand_dims(bc, axis=0)
        # single_boundary_condition = T.expand_dims(T.Variable((bc), "float32", "boundary_condition"), axis = 0)
        boundary_condition = T.concatenate(
            [single_boundary_condition,
             single_boundary_condition])  #one for phi and lambda

        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)
            to_return = T.concatenate([lambda_expanded, phi_expanded])

            # jax.lax.cond(to_return.shape == (2,10,10), lambda _: print(f'{idx}, true'), lambda _: print(f'{idx}, false'), operand = None)

            return to_return, to_return

        last_ema, all_ema = T.scan(fn,
                                   init=boundary_condition,
                                   sequences=[diag],
                                   non_sequences=[self.X])

        expanded_ema = T.concatenate(
            [T.expand_dims(boundary_condition, axis=0), all_ema])
        print(expanded_ema)
        if function:
            f = symjax.function(diag, outputs=expanded_ema)
            return f
        else:
            return expanded_ema
Example #3
 def get_ends_of_diags(self, result_ema):
     # ends_of_diags = None
     # for end in self.ends_of_calced_diags:
     #     index_test = result_ema[int(end)]
     #     ends_of_diags = self.add_or_create(ends_of_diags, index_test)
     ends_of_diags = result_ema[self.ends_of_calced_diags.astype('int')]
     prepended = T.concatenate(
         [T.expand_dims(self.boundary_condition, axis=0), ends_of_diags])
     return T.concatenate(
         [prepended,
          T.expand_dims(self.boundary_condition, axis=0)])
Example #4
def RNTK_first(x,sw,su,sb,sh,L,Lf,sv):
    X = x*x[:, None]
    n = X.shape[0] #
    gp_new = T.expand_dims(sh ** 2 * sw ** 2 * T.eye(n, n) + (su ** 2) * X + sb ** 2, axis = 0)
    rntk_new = gp_new
    for l in range(L-1):
        l = l+1
        S_new,D_new = VT(gp_new[l-1])
        gp_new = T.concatenate([gp_new,T.expand_dims(sh ** 2 * sw ** 2 * T.eye(n, n) + su**2 * S_new + sb**2,axis = 0)])
        rntk_new = T.concatenate([rntk_new,T.expand_dims(gp_new[l] + (Lf <= (l-1))*su**2*rntk_new[l-1]*D_new,axis = 0)])
    S_old,D_old = VT(gp_new[L-1])
    gp_new = T.concatenate([gp_new,T.expand_dims(sv**2*S_old,axis = 0)])
    rntk_new = T.concatenate([rntk_new,T.expand_dims(gp_new[L] + (Lf != L)*sv**2*rntk_new[L-1]*D_old,axis = 0)])
    return rntk_new, gp_new
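A minimal driver sketch for RNTK_first (not from the original source): it assumes the VT helper used throughout these examples is in scope and that x is a length-n input vector for one time step (e.g. DATA[:, 0] as in the later examples); the hyperparameter values mirror the RNTK class further below, with L and Lf illustrative.

sw, su, sb, sh, sv = 1.142, 0.5, 0.1, 1, 1   # values mirroring the RNTK class below
L, Lf = 2, 0                                 # illustrative layer counts
rntk_0, gp_0 = RNTK_first(x, sw, su, sb, sh, L, Lf, sv)
# each output stacks one (n, n) kernel slice per hidden layer plus a final output-layer slice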
Example #5
def generate_sinc_filterbank(f0, f1, J, N):

    # get the center frequencies
    freqs = get_scaled_freqs(f0, f1, J + 1)

    # store each band as (lower edge, bandwidth) and make it a variable
    freqs = np.stack([freqs[:-1], freqs[1:]], 1)
    freqs[:, 1] -= freqs[:, 0]
    freqs = T.Variable(freqs, name='c_freq')

    # parametrize the frequencies
    f0 = T.abs(freqs[:, 0])
    f1 = f0 + T.abs(freqs[:, 1])

    # sample the bandpass filters
    time = T.linspace(-N // 2, N // 2 - 1, N)
    time_matrix = time.reshape((-1, 1))
    sincs = T.signal.sinc_bandpass(time_matrix, f0, f1)

    # apodize
    apod_filters = sincs * T.signal.hanning(N).reshape((-1, 1))

    # normalize
    normed_filters = apod_filters / T.linalg.norm(
        apod_filters, 2, 0, keepdims=True)

    filters = T.transpose(T.expand_dims(normed_filters, 1), [2, 1, 0])
    return filters, freqs
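A hedged usage sketch (not part of the original module): the f0/f1 values mirror the call in create_transform further below, while J and N are illustrative.

# 64 sinc band-pass filters of length 1024, spanning 5 Hz to 22050 Hz
filters, freqs = generate_sinc_filterbank(f0=5, f1=22050, J=64, N=1024)
# `filters` is a symbolic (J, 1, N) tensor used as fixed Conv1D weights in the
# 'sinc' branch of create_transform; `freqs` is the trainable frequency Variable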
Example #6
    def RNTK_middle(self, previous,x): # line 7, alg 1
        X = x * x[:, None] # <x, x^t>
        rntk_old = previous[0]
        gp_old = previous[1]
        S_old,D_old = self.VT(gp_old[0])  # K(1, t-1)
        gp_new = T.expand_dims(self.sw ** 2 * S_old + (self.su ** 2) * X + self.sb ** 2,axis = 0) # line 8, alg 1
        if self.Lf == 0:  # if none of the layers are fixed, use the standard update
            rntk_new = T.expand_dims(gp_new[0] + self.sw**2*rntk_old[0]*D_old,axis = 0)
        else:
            rntk_new = T.expand_dims(gp_new[0],axis = 0) 
        
        print("gp_new 3", gp_new)
        print("rntk_new 3", rntk_new)

        for l in range(self.L-1): #line 10
            l = l+1
            S_new,D_new = self.VT(gp_new[l-1]) # l-1, t
            S_old,D_old = self.VT(gp_old[l]) # t-1, l
            gp_new = T.concatenate( [gp_new, T.expand_dims( self.sw ** 2 * S_old + self.su ** 2 * S_new +  self.sb ** 2, axis = 0)]) #line 10
            rntk_new = T.concatenate( [ rntk_new,  T.expand_dims( gp_new[l] +(self.Lf <= l)*self.sw**2*rntk_old[l]*D_old +(self.Lf <= (l-1))* self.su**2*rntk_new[l-1]*D_new  ,axis = 0)   ]  )
        S_old,D_old = self.VT(gp_new[self.L-1])
        gp_new = T.concatenate([gp_new,T.expand_dims(self.sv**2*S_old,axis = 0)]) # line 11
        rntk_new = T.concatenate([rntk_new,T.expand_dims(rntk_old[self.L]+ gp_new[self.L] + (self.Lf != self.L)*self.sv**2*rntk_new[self.L-1]*D_old,axis = 0)])
        print("gp_new 4", gp_new)
        print("rntk_new 4", rntk_new)
        return T.stack([rntk_new,gp_new]),x
Example #7
def generate_morlet_filterbank(N, J, Q):
    freqs = np.ones(J * Q, dtype='float32') * 5
    scales = 2**(np.linspace(-0.5, np.log2(2 * np.pi * np.log2(N)), J * Q))
    filters = T.signal.morlet(N,
                              s=0.1 + scales.reshape(
                                  (-1, 1)).astype('float32'),
                              w=freqs.reshape((-1, 1)).astype('float32'))
    filters_norm = filters / T.linalg.norm(filters, 2, 1, keepdims=True)
    return T.expand_dims(filters_norm, 1)
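A short usage sketch, not from the source: it mirrors how the 'morlet' branch of create_transform (further below) consumes the filterbank; the N, J, Q values are illustrative.

filters = generate_morlet_filterbank(N=512, J=5, Q=8)    # complex Morlet filters, shape (J*Q, 1, N)
real_w, imag_w = filters.real(), filters.imag()          # fixed Conv1D weights for the real/imaginary parts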
Example #8
 def scan_func(prevq, MINIDATAT):  # MINIDATAT should be a vector of length N
     # print(MINIDATAT)
     # the trick to this one is to use the original VT
     S = self.alg1_VT(prevq)  # -> M is K3
     # S = prevq
     # newq = self.su * 2 * T.linalg.norm(T.expand_dims(xz, axis = 0), ord = 2, axis = 0) + self.sb**2 + self.sh**2
     newq = self.sw**2 * S + self.su**2 * T.linalg.norm(T.expand_dims(MINIDATAT, axis = 0), ord = 2, axis = 0) + self.sb**2
     # print("newq", newq)
     return newq, newq
Example #9
        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)
            to_return = T.concatenate([lambda_expanded, phi_expanded])

            # jax.lax.cond(to_return.shape == (2,10,10), lambda _: print(f'{idx}, true'), lambda _: print(f'{idx}, false'), operand = None)

            return to_return, to_return
Example #10
def generate_gaussian_filterbank(N, M, J, f0, f1, modes=1):

    # gaussian parameters
    freqs = get_scaled_freqs(f0, f1, J)
    freqs *= (J - 1) * 10

    if modes > 1:
        other_modes = np.random.randint(0, J, J * (modes - 1))
        freqs = np.concatenate([freqs, freqs[other_modes]])

    # create the mean vectors
    mu_init = np.stack([freqs, 0.1 * np.random.randn(J * modes)], 1)
    mu = T.Variable(mu_init.astype('float32'), name='mu')

    # create the covariance matrix
    cor = T.Variable(0.01 * np.random.randn(J * modes).astype('float32'),
                     name='cor')

    sigma_init = np.stack([freqs / 6, 1. + 0.01 * np.random.randn(J * modes)],
                          1)
    sigma = T.Variable(sigma_init.astype('float32'), name='sigma')

    # create the mixing coefficients
    mixing = T.Variable(np.ones((modes, 1, 1)).astype('float32'))

    # now apply our parametrization
    coeff = T.stop_gradient(T.sqrt((T.abs(sigma) + 0.1).prod(1))) * 0.95
    Id = T.eye(2)
    cov = Id * T.expand_dims((T.abs(sigma)+0.1),1) +\
            T.flip(Id, 0) * (T.tanh(cor) * coeff).reshape((-1, 1, 1))
    cov_inv = T.linalg.inv(cov)

    # get the gaussian filters
    time = T.linspace(-5, 5, M)
    freq = T.linspace(0, J * 10, N)
    x, y = T.meshgrid(time, freq)
    grid = T.stack([y.flatten(), x.flatten()], 1)
    centered = grid - T.expand_dims(mu, 1)
    gaussian = T.exp(-(T.matmul(centered, cov_inv)**2).sum(-1))
    norm = T.linalg.norm(gaussian, 2, 1, keepdims=True)
    gaussian_2d = T.abs(mixing) * T.reshape(gaussian / norm, (J, modes, N, M))
    return gaussian_2d.sum(1, keepdims=True), mu, cor, sigma, mixing
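A hedged usage sketch (not in the original file), matching the call made in the 'wvd' branch of create_transform below; the N, M and J values here are illustrative.

filters, mu, cor, sigma, mixing = generate_gaussian_filterbank(
    N=1024, M=64, J=40, f0=5, f1=22050, modes=3)
# `filters` has shape (J, 1, N, M) and is convolved with the Wigner-Ville distribution;
# mu, cor, sigma and mixing are the trainable Variables registered on the layer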
Example #11
 def RNTK_first(self,x): # alg 1, line 1
     X = x*x[:, None]
     n = X.shape[0] #       // creates a diagonal matrix of sh^2 * sw^2
     test = self.sh ** 2 * self.sw ** 2 * T.eye(n, n) + (self.su ** 2) * X + self.sb ** 2
     gp_new = T.expand_dims(test, axis = 0) # line 2, alg 1 #GP IS GAMMA, RNTK IS PHI
     rntk_new = gp_new
     print("gp_new 1", gp_new)
     print("rntk_new 1", rntk_new)
     for l in range(self.L-1): #line 3, alg 1
         l = l+1
         print("gp_new", gp_new[l-1])
         S_new,D_new = self.VT(gp_new[l-1]) 
         gp_new = T.concatenate([gp_new,T.expand_dims(self.sh ** 2 * self.sw ** 2 * T.eye(n, n) + self.su**2 * S_new + self.sb**2,axis = 0)]) #line 4, alg 1
         rntk_new = T.concatenate([rntk_new,T.expand_dims(gp_new[l] + (self.Lf <= (l-1))*self.su**2*rntk_new[l-1]*D_new,axis = 0)])
     S_old,D_old = self.VT(gp_new[self.L-1])
     gp_new = T.concatenate([gp_new,T.expand_dims(self.sv**2*S_old,axis = 0)]) #line 5, alg 1
     rntk_new = T.concatenate([rntk_new,T.expand_dims(gp_new[self.L] + (self.Lf != self.L)*self.sv**2*rntk_new[self.L-1]*D_old,axis = 0)])
     print("gp_new 2", gp_new)
     print("rntk_new 2", rntk_new)
     return rntk_new, gp_new
Example #12
    def compute_q(self, DATA):
        DATAT = T.transpose(DATA)
        xz = DATAT[0]
        # print(xz)
        init = self.su ** 2 * T.linalg.norm(T.expand_dims(xz, axis = 0), ord = 2, axis = 0) + self.sb**2 + self.sh**2
        # print("init", init)
        # init = self.su * 2 * T.linalg.norm(xz, ord = 2) + self.sb**2 + self.sh**2 #make this a vectorized function 

        def scan_func(prevq, MINIDATAT):  # MINIDATAT should be a vector of length N
            # print(MINIDATAT)
            # the trick to this one is to use the original VT
            S = self.alg1_VT(prevq) # -> M is K3 
            # S = prevq
            # newq = self.su * 2 * T.linalg.norm(T.expand_dims(xz, axis = 0), ord = 2, axis = 0) + self.sb**2 + self.sh**2
            newq = self.sw**2 * S + self.su**2 * T.linalg.norm(T.expand_dims(MINIDATAT, axis = 0), ord = 2, axis = 0) + self.sb**2
            # print("newq", newq)
            return newq, newq
        
        last_ema, all_ema = T.scan(scan_func, init = init, sequences = [DATAT[1:]])
        return T.concatenate([T.expand_dims(init, axis = 0), all_ema])
Example #13
def generate_learnmorlet_filterbank(N, J, Q):
    freqs = T.Variable(np.ones(J * Q) * 5)
    scales = 2**(np.linspace(-0.5, np.log2(2 * np.pi * np.log2(N)), J * Q))
    scales = T.Variable(scales)

    filters = T.signal.morlet(N,
                              s=0.01 + T.abs(scales.reshape((-1, 1))),
                              w=freqs.reshape((-1, 1)))
    filters_norm = filters / T.linalg.norm(filters, 2, 1, keepdims=True)

    return T.expand_dims(filters_norm, 1), freqs, scales
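A brief usage sketch (illustrative values, not from the source) for the learnable-Morlet variant, which additionally returns its frequency and scale Variables so they can be trained.

filters, freqs, scales = generate_learnmorlet_filterbank(N=512, J=5, Q=8)
# create_transform (below) uses T.real(filters) / T.imag(filters) as Conv1D weights
# and registers freqs and scales on the layer via add_variable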
Example #14
    def reorganize_data(self, DATA, printbool=False):
        TiPrimes, Tis = self.get_diag_indices(printbool=printbool)
        reorganized_data = None
        for diag_idx in range(1, self.dim_num - 1):
            TiP = TiPrimes[diag_idx]
            Ti = Tis[diag_idx]
            dim_len = self.dim_lengths[diag_idx]
            for diag_pos in range(1, int(dim_len)):
                # we should never see 0 here, since those are reserved for boundary conditions
                if printbool:
                    print(
                        f"taking position {TiP + diag_pos} from TiP, {Ti + diag_pos} from Ti"
                    )

                reorganized_data = self.add_or_create(
                    reorganized_data,
                    T.concatenate([
                        T.expand_dims(DATA[:, TiP + diag_pos], axis=0),
                        T.expand_dims(DATA[:, Ti + diag_pos], axis=0)
                    ]))
        return reorganized_data
Example #15
    def create_func_for_diag(self):

        # NEW_DATA = self.reorganize_data()

        ## change - bc should be a function that takes a passed in X?

        bc = self.sh**2 * self.sw**2 * T.eye(self.n, self.n) + (
            self.su**2) * self.X + self.sb**2  ## took out X ||
        single_boundary_condition = T.expand_dims(bc, axis=0)
        # single_boundary_condition = T.expand_dims(T.Variable((bc), "float32", "boundary_condition"), axis = 0)
        boundary_condition = T.concatenate(
            [single_boundary_condition, single_boundary_condition])
        self.boundary_condition = boundary_condition

        # self.save_vts = {}

        ## prev_vals - (2,1) - previous phi and lambda values
        ## idx - where we are on the diagonal
        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in
            # x = Xph['indexed']
            # X = x*x[:, None]

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)

            to_return = T.concatenate([lambda_expanded, phi_expanded])

            if idx in self.ends_of_calced_diags:
                # self.save_vts[idx] = [S,D]
                return boundary_condition, to_return
            return to_return, to_return

        last_ema, all_ema = T.scan(
            fn,
            init=boundary_condition,
            sequences=[jnp.arange(0,
                                  sum(self.dim_lengths) - self.dim_num)],
            non_sequences=[self.X])
        # if fbool:

        #     return all_ema, f
        return self.compute_kernels(all_ema)
Example #16
        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in
            # x = Xph['indexed']
            # X = x*x[:, None]

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)

            to_return = T.concatenate([lambda_expanded, phi_expanded])

            if idx in self.ends_of_calced_diags:
                # self.save_vts[idx] = [S,D]
                return boundary_condition, to_return
            return to_return, to_return
Example #17
def RNTK_middle(previous,x,sw,su,sb,L,Lf,sv):
    X = x * x[:, None]
    rntk_old = previous[0]
    gp_old = previous[1]
    S_old,D_old = VT(gp_old[0])
    gp_new = T.expand_dims(sw ** 2 * S_old + (su ** 2) * X + sb ** 2,axis = 0)
    if Lf == 0:
        rntk_new = T.expand_dims(gp_new[0] + sw**2*rntk_old[0]*D_old,axis = 0)
    else:
        rntk_new = T.expand_dims(gp_new[0],axis = 0)
    for l in range(L-1):
        l = l+1
        S_new,D_new = VT(gp_new[l-1])
        S_old,D_old = VT(gp_old[l])
        gp_new = T.concatenate( [gp_new, T.expand_dims( sw ** 2 * S_old + su ** 2 * S_new +  sb ** 2, axis = 0)])
        rntk_new = T.concatenate( [ rntk_new,  T.expand_dims( gp_new[l] +(Lf <= l)*sw**2*rntk_old[l]*D_old +(Lf <= (l-1))* su**2*rntk_new[l-1]*D_new  ,axis = 0)   ]  )
    S_old,D_old = VT(gp_new[L-1])
    gp_new = T.concatenate([gp_new,T.expand_dims(sv**2*S_old,axis = 0)])
    rntk_new = T.concatenate([rntk_new,T.expand_dims(rntk_old[L]+ gp_new[L] + (Lf != L)*sv**2*rntk_new[L-1]*D_old,axis = 0)])
    return T.stack([rntk_new,gp_new]),x
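RNTK_middle is written as a scan body: it takes the previous (rntk, gp) stack and one time step, and returns the updated stack as both carry and output. Below is a sketch of driving it with T.scan, assuming RNTK_first (Example #4) supplies the initial carry, a symbolic DATA matrix of shape (n, time steps) as in the surrounding examples, and the class hyperparameters; L and Lf are illustrative.

sw, su, sb, sh, sv = 1.142, 0.5, 0.1, 1, 1
L, Lf = 2, 0
rntk_0, gp_0 = RNTK_first(DATA[:, 0], sw, su, sb, sh, L, Lf, sv)
init = T.stack([rntk_0, gp_0])
step = lambda prev, x: RNTK_middle(prev, x, sw, su, sb, L, Lf, sv)
last, all_steps = T.scan(step, init=init, sequences=[T.transpose(DATA)[1:]])
# `last[0]` holds the RNTK entries and `last[1]` the GP kernel after the final time step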
Example #18
    # x = DATA[:,0]
    # X = x*x[:, None]
    # n = X.shape[0]

    rntkod = RNTK(dic, DATA, DATAPRIME) #could be flipped 

    start = time.time()
    kernels_ema = rntkod.create_func_for_diag()
    diag_func = symjax.function(DATA, DATAPRIME, outputs=kernels_ema)
    if printbool:
        print("time to create symjax", time.time() - start)

    return diag_func, rntkod
    # return None, rntkod

create_T_list = lambda vals: T.concatenate([T.expand_dims(i, axis = 0) for i in vals])

class RNTK():
    def __init__(self, dic, DATA, DATAPRIME, simple = False):
        if not simple:
            self.dim_1 = dic["n_entradasTiP="]
            self.dim_2 = dic["n_entradasTi="]
            self.dim_num = self.dim_1 + self.dim_2 + 1
            self.DATA = DATA
            self.DATAPRIME = DATAPRIME
            self.N = int(dic["n_patrons1="])
        self.sw = 1.142 #sqrt 2 (1.4) - FIXED
        self.su = 0.5 #[0.1,0.5,1] - SEARCH
        self.sb = 0.1 #[0, 0.2, 0.5] - SEARCH
        self.sh = 1 #[0, 0.5, 1] - SEARCH
        self.sv = 1 #1 - FIXED
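A construction sketch for this RNTK class (not in the source): simple=True sets only the kernel hyperparameters, while the full constructor reads the dic keys used in __init__ (illustrative values below) and expects the symbolic DATA / DATAPRIME inputs used by the fragment above.

rntk_simple = RNTK(dic=None, DATA=None, DATAPRIME=None, simple=True)

dic = {"n_entradasTiP=": 7, "n_entradasTi=": 7, "n_patrons1=": 100}
rntkod = RNTK(dic, DATA, DATAPRIME)
kernels_ema = rntkod.create_func_for_diag()   # symbolic kernels, as in the fragment above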
Example #19
def create_transform(input, args):
    input_r = input.reshape((args.BS, 1, -1))

    if args.option == 'melspec':
        layer = [
            T.signal.melspectrogram(input_r,
                                    window=args.bins,
                                    hop=args.hop,
                                    n_filter=args.J * args.Q,
                                    low_freq=3,
                                    high_freq=22050,
                                    nyquist=22050,
                                    mode='same')
        ]

    elif args.option == 'raw':
        layer = [
            layers.Conv1D(input_r,
                          strides=args.hop,
                          W_shape=(args.J * args.Q, 1, args.bins),
                          trainable_b=False,
                          pad='SAME')
        ]
        layer.append(layers.Lambda(T.expand_dims(layer[-1], 1), T.abs))

    elif args.option == 'morlet':
        filters = generate_morlet_filterbank(args.bins, args.J, args.Q)
        layer = [
            layers.Conv1D(input_r,
                          args.J * args.Q,
                          args.bins,
                          W=filters.real(),
                          trainable_W=False,
                          stride=args.hop,
                          trainable_b=False,
                          pad='SAME')
        ]
        layer.append(
            layers.Conv1D(input_r,
                          args.J * args.Q,
                          args.bins,
                          W=filters.imag(),
                          trainable_W=False,
                          stride=args.hop,
                          trainable_b=False,
                          pad='SAME'))
        layer.append(T.sqrt(layer[-1]**2 + layer[-2]**2))
        layer.append(T.expand_dims(layer[-1], 1))

    elif args.option == 'learnmorlet':
        filters, freqs, scales = generate_learnmorlet_filterbank(
            args.bins, args.J, args.Q)

        layer = [
            layers.Conv1D(input_r,
                          args.J * args.Q,
                          args.bins,
                          W=T.real(filters),
                          trainable_W=False,
                          stride=args.hop,
                          trainable_b=False,
                          pad='SAME')
        ]
        layer.append(
            layers.Conv1D(input_r,
                          args.J * args.Q,
                          args.bins,
                          W=T.imag(filters),
                          trainable_W=False,
                          stride=args.hop,
                          trainable_b=False,
                          pad='SAME'))
        layer[0].add_variable(freqs)
        layer[0].add_variable(scales)
        layer[0]._filters = filters
        layer[0]._scales = scales
        layer[0]._freqs = freqs
        layer.append(T.sqrt(layer[-1]**2 + layer[-2]**2 + 0.001))
        layer.append(T.expand_dims(layer[-1], 1))

    elif 'wvd' in args.option:
        WVD = T.signal.wvd(input_r,
                           window=args.bins * 2,
                           L=args.L * 2,
                           hop=args.hop,
                           mode='same')
        if args.option == 'wvd':
            modes = 1
        else:
            modes = 3
        filters, mu, cor, sigma, mixing = generate_gaussian_filterbank(
            args.bins, 64, args.J * args.Q, 5, 22050, modes)
        print(WVD)
        #        filters=T.random.randn((args.J * args.Q, 1, args.bins*2, 5))
        wvd = T.convNd(WVD, filters)[:, :, 0]
        print('wvd', wvd)
        layer = [layers.Identity(T.expand_dims(wvd, 1))]
        layer[-1].add_variable(mu)
        layer[-1].add_variable(cor)
        layer[-1].add_variable(sigma)
        layer[-1].add_variable(mixing)
        layer[-1]._mu = mu
        layer[-1]._cor = cor
        layer[-1]._sigma = sigma
        layer[-1]._mixing = mixing
        layer[-1]._filter = filters
        layer.append(layers.Lambda(layer[-1], T.abs))

    elif args.option == 'sinc':
        filters, freq = generate_sinc_filterbank(5, 22050, args.J * args.Q,
                                                 args.bins)
        layer = [
            layers.Conv1D(input.reshape((args.BS, 1, -1)),
                          args.J * args.Q,
                          args.bins,
                          W=filters,
                          stride=args.hop,
                          trainable_b=False,
                          trainable_W=False,
                          pad='SAME')
        ]
        layer[-1]._freq = freq
        layer[-1]._filter = filters
        layer[-1].add_variable(freq)
        layer.append(T.expand_dims(layer[-1], 1))
        layer.append(layers.Lambda(layer[-1], T.abs))
    return layer
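A hedged construction sketch (not from the original script): `args` would normally come from argparse, so a SimpleNamespace with illustrative values stands in for it, and a dummy zero batch stands in for the real input tensor.

from types import SimpleNamespace

args = SimpleNamespace(BS=8, bins=512, hop=256, J=5, Q=8, L=4, option='melspec')
signal = T.zeros((args.BS, 2**16))     # dummy waveform batch, only to build the graph
layer = create_transform(signal, args)
representation = layer[-1]             # time-frequency front end fed to the rest of the model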
Example #20
 def add_or_create(self, tlist, titem):
     if tlist is None:
         return T.expand_dims(titem, axis=0)
     else:
         return T.concatenate([tlist, T.expand_dims(titem, axis=0)])
    # n = X.shape[0]

    rntkod = RNTK(dic, DATA, DATAPRIME)  #could be flipped

    start = time.time()
    kernels_ema = rntkod.create_func_for_diag()
    diag_func = symjax.function(DATA, DATAPRIME, outputs=kernels_ema)
    if printbool:
        print("time to create symjax", time.time() - start)

    return diag_func, rntkod
    # return None, rntkod


create_T_list = lambda vals: T.concatenate(
    [T.expand_dims(i, axis=0) for i in vals])


class RNTK():
    def __init__(self, dic, DATA, DATAPRIME, simple=False):
        if not simple:
            self.dim_1 = dic["n_entradasTiP="]
            self.dim_2 = dic["n_entradasTi="]
            self.dim_num = self.dim_1 + self.dim_2 + 1
            self.DATA = DATA
            self.DATAPRIME = DATAPRIME
            self.N = int(dic["n_patrons1="])
        self.sw = 1.142  #sqrt 2 (1.4) - FIXED
        self.su = 0.5  #[0.1,0.5,1] - SEARCH
        self.sb = 0.1  #[0, 0.2, 0.5] - SEARCH
        self.sh = 1  #[0, 0.5, 1] - SEARCH