Example #1
    def RNTK_middle(self, previous,x): # line 7, alg 1
        X = x * x[:, None] # <x, x^t>
        rntk_old = previous[0]
        gp_old = previous[1]
        S_old,D_old = self.VT(gp_old[0])  # K(1, t-1)
        gp_new = T.expand_dims(self.sw ** 2 * S_old + (self.su ** 2) * X + self.sb ** 2,axis = 0) # line 8, alg 1
        if self.Lf == 0: # if none of the layers are fixed, use the standard update
            rntk_new = T.expand_dims(gp_new[0] + self.sw**2*rntk_old[0]*D_old,axis = 0)
        else:
            rntk_new = T.expand_dims(gp_new[0],axis = 0) 
        
        print("gp_new 3", gp_new)
        print("rntk_new 3", rntk_new)

        for l in range(self.L-1): #line 10
            l = l+1
            S_new,D_new = self.VT(gp_new[l-1]) # l-1, t
            S_old,D_old = self.VT(gp_old[l]) # t-1, l
            gp_new = T.concatenate( [gp_new, T.expand_dims( self.sw ** 2 * S_old + self.su ** 2 * S_new +  self.sb ** 2, axis = 0)]) #line 10
            rntk_new = T.concatenate( [ rntk_new,  T.expand_dims( gp_new[l] +(self.Lf <= l)*self.sw**2*rntk_old[l]*D_old +(self.Lf <= (l-1))* self.su**2*rntk_new[l-1]*D_new  ,axis = 0)   ]  )
        S_old,D_old = self.VT(gp_new[self.L-1])
        gp_new = T.concatenate([gp_new,T.expand_dims(self.sv**2*S_old,axis = 0)]) # line 11
        rntk_new = T.concatenate([rntk_new,T.expand_dims(rntk_old[self.L]+ gp_new[self.L] + (self.Lf != self.L)*self.sv**2*rntk_new[self.L-1]*D_old,axis = 0)])
        print("gp_new 4", gp_new)
        print("rntk_new 4", rntk_new)
        return T.stack([rntk_new,gp_new]),x
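The helper self.VT used above (and throughout the RNTK examples below) is not shown. As a point of reference, here is a minimal NumPy sketch of what it is assumed to return: the ReLU arc-cosine expectations S = E[phi(z)phi(z')] and D = E[phi'(z)phi'(z')] for an (n, n) GP covariance matrix. This is an assumption based on the standard RNTK derivation, not the original implementation.

import numpy as np

def VT_sketch(K):
    # Hypothetical stand-in for self.VT (assumption, not the original code):
    # K is an (n, n) GP covariance matrix; S and D are the ReLU expectations.
    d = np.diag(K)                               # variances K(t, t)
    scale = np.sqrt(d[:, None] * d[None, :])     # sqrt(K_ii * K_jj)
    cos_t = np.clip(K / scale, -1.0, 1.0)
    theta = np.arccos(cos_t)
    S = scale * (np.sin(theta) + (np.pi - theta) * cos_t) / (2 * np.pi)
    D = (np.pi - theta) / (2 * np.pi)
    return S, D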
Example #2
def create_fns(input, in_signs, Ds):

    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    
    Ws = [sj.initializers.he((j, i)) for j, i in zip(Ds[1:], Ds[:-1])]
    bs = [sj.initializers.he((j,)) for j in Ds[1:]]

    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]
    
    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]
    
    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]
    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1.,
                                     0.1)

    for w, b in zip(Ws[:-1], bs[:-1]):
        
        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., 0.1))

        maps.append(pre_activation * masks[-1])

    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1],
                                   cumulative_units[1:], Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

    signs = T.concatenate(signs)
    ineq_b = T.concatenate(B_w[1:-1])
    ineq_A = T.vstack(A_w[1:-1])

    inequalities = T.hstack([ineq_b[:, None], ineq_A])
    inequalities = inequalities * signs[:, None] / T.linalg.norm(ineq_A, 2,
                                                         1, keepdims=True)

    inequalities_code = T.hstack([T.concatenate(B_q[1:-1])[:, None],
                                  T.vstack(A_q[1:-1])])
    inequalities_code = inequalities_code * in_signs[:, None]

    f = sj.function(input, outputs=[maps[-1], A_w[-1], B_w[-1],
                                    inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])

    return f, g, h, all_g
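A hypothetical call pattern for the functions returned above. The placeholder shapes follow what create_fns expects; the layer widths in Ds and the random test point are assumptions for illustration only.

import numpy as np
import symjax as sj
import symjax.tensor as T

Ds = [2, 8, 8, 1]                                          # assumed layer widths
input = T.Placeholder((Ds[0],), 'float32')                 # a single input point
in_signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')   # one sign per hidden unit
f, g, h, all_g = create_fns(input, in_signs, Ds)

z = np.random.randn(Ds[0]).astype('float32')
out, A, B, ineqs, signs = f(z)   # output, per-region affine map A z + B, region inequalities, signs
A_code, B_code = g(signs)        # the same affine map recovered from the sign pattern alone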
Example #3
    def forward(self, input, crop_shape, deterministic, padding=0, seed=None):

        self.crop_shape = crop_shape
        # if given only a scalar
        if not hasattr(padding, "__len__"):
            self.pad_shape = [(padding, padding)] * (input.ndim - 1)
        # else
        else:
            self.pad_shape = [(pad,
                               pad) if not hasattr(pad, "__len__") else pad
                              for pad in padding]

        assert len(self.pad_shape) == len(self.crop_shape)
        assert len(self.pad_shape) == (len(input.shape) - 1)

        self.start_indices = list()
        self.fixed_indices = list()
        for i, (pad, dim, crop) in enumerate(
                zip(self.pad_shape, input.shape[1:], self.crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            assert maxval >= 0
            self.start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    seed=seed + i if seed is not None else seed,
                ))

            self.fixed_indices.append(
                T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        self.start_indices = T.concatenate(self.start_indices, 1)
        self.fixed_indices = T.concatenate(self.fixed_indices, 1)

        dirac = T.cast(deterministic, "float32")

        # pad the input
        pinput = T.pad(input, [(0, 0)] + self.pad_shape)

        routput = T.stack(
            [
                T.dynamic_slice(pinput[n], self.start_indices[n],
                                self.crop_shape) for n in range(input.shape[0])
            ],
            0,
        )
        doutput = T.stack(
            [
                T.dynamic_slice(pinput[n], self.fixed_indices[n],
                                self.crop_shape) for n in range(input.shape[0])
            ],
            0,
        )

        return doutput * dirac + (1 - dirac) * routput
    def create_func_for_diag(self,
                             dim1idx,
                             dim2idx,
                             function=False,
                             jmode=False):
        diag = self.make_inputs(dim1idx, dim2idx, jmode=jmode)
        # print('test')

        ## prev_vals - (2,1) - previous phi and lambda values
        ## idx - where we are on the diagonal
        ## d1idx - y value of first dimension diag start
        ## d2idx - x value of second dimension diag start
        ## d1ph - max value of first dimension
        ## d2ph - max value of second dimension
        bc = self.sh**2 * self.sw**2 * T.eye(
            self.n, self.n) + (self.su**2) * self.X + self.sb**2
        single_boundary_condition = T.expand_dims(bc, axis=0)
        # single_boundary_condition = T.expand_dims(T.Variable((bc), "float32", "boundary_condition"), axis = 0)
        boundary_condition = T.concatenate(
            [single_boundary_condition,
             single_boundary_condition])  #one for phi and lambda

        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)
            to_return = T.concatenate([lambda_expanded, phi_expanded])

            # jax.lax.cond(to_return.shape == (2,10,10), lambda _: print(f'{idx}, true'), lambda _: print(f'{idx}, false'), operand = None)

            return to_return, to_return

        last_ema, all_ema = T.scan(fn,
                                   init=boundary_condition,
                                   sequences=[diag],
                                   non_sequences=[self.X])

        expanded_ema = T.concatenate(
            [T.expand_dims(boundary_condition, axis=0), all_ema])
        print(expanded_ema)
        if function:
            f = symjax.function(diag, outputs=expanded_ema)
            return f
        else:
            return expanded_ema
 def get_ends_of_diags(self, result_ema):
     # ends_of_diags = None
     # for end in self.ends_of_calced_diags:
     #     index_test = result_ema[int(end)]
     #     ends_of_diags = self.add_or_create(ends_of_diags, index_test)
     ends_of_diags = result_ema[self.ends_of_calced_diags.astype('int')]
     prepended = T.concatenate(
         [T.expand_dims(self.boundary_condition, axis=0), ends_of_diags])
     return T.concatenate(
         [prepended,
          T.expand_dims(self.boundary_condition, axis=0)])
Example #6
def RNTK_first(x,sw,su,sb,sh,L,Lf,sv):
    X = x*x[:, None]
    n = X.shape[0] #
    gp_new = T.expand_dims(sh ** 2 * sw ** 2 * T.eye(n, n) + (su ** 2) * X + sb ** 2, axis = 0)
    rntk_new = gp_new
    for l in range(L-1):
        l = l+1
        S_new,D_new = VT(gp_new[l-1])
        gp_new = T.concatenate([gp_new,T.expand_dims(sh ** 2 * sw ** 2 * T.eye(n, n) + su**2 * S_new + sb**2,axis = 0)])
        rntk_new = T.concatenate([rntk_new,T.expand_dims(gp_new[l] + (Lf <= (l-1))*su**2*rntk_new[l-1]*D_new,axis = 0)])
    S_old,D_old = VT(gp_new[L-1])
    gp_new = T.concatenate([gp_new,T.expand_dims(sv**2*S_old,axis = 0)])
    rntk_new = T.concatenate([rntk_new,T.expand_dims(gp_new[L] + (Lf != L)*sv**2*rntk_new[L-1]*D_old,axis = 0)])
    return rntk_new, gp_new
Example #7
    def __init__(self, input, crop_shape, deterministic, padding=0, seed=None):

        # if given only a scalar
        if not hasattr(padding, "__len__"):
            pad_shape = [(padding, padding)] * (input.ndim - 1)
        # else
        else:
            pad_shape = [(pad, pad) if not hasattr(pad, "__len__") else pad
                         for pad in padding]

        assert len(pad_shape) == len(crop_shape)
        assert len(pad_shape) == input.ndim - 1

        start_indices = list()
        fixed_indices = list()
        for i, (pad, dim,
                crop) in enumerate(zip(pad_shape, input.shape[1:],
                                       crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    seed=seed + i if seed is not None else seed,
                ))

            fixed_indices.append(
                T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        start_indices = T.concatenate(start_indices, 1)
        fixed_indices = T.concatenate(fixed_indices, 1)

        dirac = T.cast(deterministic, "float32")

        # pad the input
        pinput = T.pad(input, [(0, 0)] + pad_shape)

        routput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, start_indices],
        )
        doutput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, fixed_indices],
        )

        return doutput * dirac + (1 - dirac) * routput
    def compute_kernels(self, final_ema):

        diag_ends = self.get_ends_of_diags(final_ema)

        S_init, D_init = self.VT(diag_ends[0][0])
        init_Kappa = self.sv**2 * S_init
        init_Theta = init_Kappa + self.sv**2 * diag_ends[0][1] * D_init
        init_list = T.concatenate([
            T.expand_dims(init_Kappa, axis=0),
            T.expand_dims(init_Theta, axis=0)
        ])

        def map_test(gp_rntk_sum, gp_rntk):
            S, D = self.VT(gp_rntk[0])
            ret1 = self.sv**2 * S
            ret2 = ret1 + self.sv**2 * gp_rntk[1] * D
            gp_rntk_sum = T.index_add(gp_rntk_sum, 0, ret1)
            gp_rntk_sum = T.index_add(gp_rntk_sum, 1, ret2)

            return gp_rntk_sum, gp_rntk_sum

        final_K_T, inter_results = T.scan(map_test,
                                          init=init_list,
                                          sequences=[diag_ends[1:]])
        return final_K_T
Example #9
    def __init__(self,
                 input_or_shape,
                 crop_shape,
                 deterministic,
                 padding=0,
                 seed=None):

        self.init_input(input_or_shape)
        self.crop_shape = crop_shape
        # if given only a scalar
        if not hasattr(padding, '__len__'):
            self.pad_shape = [(padding, padding)] * (len(self.input.shape) - 1)
        # else
        else:
            self.pad_shape = [(pad,
                               pad) if not hasattr(pad, '__len__') else pad
                              for pad in padding]

        assert len(self.pad_shape) == len(self.crop_shape)
        assert len(self.pad_shape) == (len(self.input.shape) - 1)

        self.deterministic = deterministic

        self.start_indices = list()
        self.fixed_indices = list()
        for i, (pad, dim, crop) in enumerate(
                zip(self.pad_shape, self.input.shape[1:], self.crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            assert maxval >= 0
            self.start_indices.append(
                T.random.randint(minval=0,
                                 maxval=maxval,
                                 shape=(self.input.shape[0], 1),
                                 dtype='int32',
                                 seed=seed + i if seed is not None else seed))

            self.fixed_indices.append(
                T.ones((self.input.shape[0], 1), 'int32') * (maxval // 2))
        self.start_indices = T.concatenate(self.start_indices, 1)
        self.fixed_indices = T.concatenate(self.fixed_indices, 1)

        super().__init__(self.forward(self.input))
Example #10
def RNTK_middle(previous,x,sw,su,sb,L,Lf,sv):
    X = x * x[:, None]
    rntk_old = previous[0]
    gp_old = previous[1]
    S_old,D_old = VT(gp_old[0])
    gp_new = T.expand_dims(sw ** 2 * S_old + (su ** 2) * X + sb ** 2,axis = 0)
    if Lf == 0:
        rntk_new = T.expand_dims(gp_new[0] + sw**2*rntk_old[0]*D_old,axis = 0)
    else:
        rntk_new = T.expand_dims(gp_new[0],axis = 0)
    for l in range(L-1):
        l = l+1
        S_new,D_new = VT(gp_new[l-1])
        S_old,D_old = VT(gp_old[l])
        gp_new = T.concatenate( [gp_new, T.expand_dims( sw ** 2 * S_old + su ** 2 * S_new +  sb ** 2, axis = 0)])
        rntk_new = T.concatenate( [ rntk_new,  T.expand_dims( gp_new[l] +(Lf <= l)*sw**2*rntk_old[l]*D_old +(Lf <= (l-1))* su**2*rntk_new[l-1]*D_new  ,axis = 0)   ]  )
    S_old,D_old = VT(gp_new[L-1])
    gp_new = T.concatenate([gp_new,T.expand_dims(sv**2*S_old,axis = 0)])
    rntk_new = T.concatenate([rntk_new,T.expand_dims(rntk_old[L]+ gp_new[L] + (Lf != L)*sv**2*rntk_new[L-1]*D_old,axis = 0)])
    return T.stack([rntk_new,gp_new]),x
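RNTK_first (Example #6) and the function-style RNTK_middle above follow the carry/output convention of symjax's T.scan, so they can be chained over the time dimension. The sketch below shows one way to do that; the (time_steps, n) data layout and the chaining itself are assumptions for illustration, not code from the original project.

import symjax.tensor as T

def rntk_over_time(data, sw, su, sb, sh, L, Lf, sv):
    # data is assumed to hold one length-n slice per time step: shape (time_steps, n)
    rntk0, gp0 = RNTK_first(data[0], sw, su, sb, sh, L, Lf, sv)  # boundary condition at t = 1
    init = T.stack([rntk0, gp0])
    # every later step consumes the previous (rntk, gp) pair as the scan carry
    last, _ = T.scan(
        lambda prev, x: RNTK_middle(prev, x, sw, su, sb, L, Lf, sv),
        init=init,
        sequences=[data[1:]],
    )
    return last[0], last[1]   # per-layer RNTK and GP kernels after the final step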
Example #11
 def RNTK_first(self,x): # alg 1, line 1
     X = x*x[:, None]
     n = X.shape[0]  # creates a diagonal matrix of sh^2 * sw^2
     test = self.sh ** 2 * self.sw ** 2 * T.eye(n, n) + (self.su ** 2) * X + self.sb ** 2
     gp_new = T.expand_dims(test, axis = 0) # line 2, alg 1 #GP IS GAMMA, RNTK IS PHI
     rntk_new = gp_new
     print("gp_new 1", gp_new)
     print("rntk_new 1", rntk_new)
     for l in range(self.L-1): #line 3, alg 1
         l = l+1
         print("gp_new", gp_new[l-1])
         S_new,D_new = self.VT(gp_new[l-1]) 
         gp_new = T.concatenate([gp_new,T.expand_dims(self.sh ** 2 * self.sw ** 2 * T.eye(n, n) + self.su**2 * S_new + self.sb**2,axis = 0)]) #line 4, alg 1
         rntk_new = T.concatenate([rntk_new,T.expand_dims(gp_new[l] + (self.Lf <= (l-1))*self.su**2*rntk_new[l-1]*D_new,axis = 0)])
     S_old,D_old = self.VT(gp_new[self.L-1])
     gp_new = T.concatenate([gp_new,T.expand_dims(self.sv**2*S_old,axis = 0)]) #line 5, alg 1
     rntk_new = T.concatenate([rntk_new,T.expand_dims(gp_new[self.L] + (self.Lf != self.L)*self.sv**2*rntk_new[self.L-1]*D_old,axis = 0)])
     print("gp_new 2", gp_new)
     print("rntk_new 2", rntk_new)
     return rntk_new, gp_new
    def create_func_for_diag(self):

        # NEW_DATA = self.reorganize_data()

        ## change - bc should be a function that takes a passed in X?

        bc = self.sh**2 * self.sw**2 * T.eye(self.n, self.n) + (
            self.su**2) * self.X + self.sb**2  ## took out X ||
        single_boundary_condition = T.expand_dims(bc, axis=0)
        # single_boundary_condition = T.expand_dims(T.Variable((bc), "float32", "boundary_condition"), axis = 0)
        boundary_condition = T.concatenate(
            [single_boundary_condition, single_boundary_condition])
        self.boundary_condition = boundary_condition

        # self.save_vts = {}

        ## prev_vals - (2,1) - previous phi and lambda values
        ## idx - where we are on the diagonal
        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in
            # x = Xph['indexed']
            # X = x*x[:, None]

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)

            to_return = T.concatenate([lambda_expanded, phi_expanded])

            if idx in self.ends_of_calced_diags:
                # self.save_vts[idx] = [S,D]
                return boundary_condition, to_return
            return to_return, to_return

        last_ema, all_ema = T.scan(
            fn,
            init=boundary_condition,
            sequences=[jnp.arange(0,
                                  sum(self.dim_lengths) - self.dim_num)],
            non_sequences=[self.X])
        # if fbool:

        #     return all_ema, f
        return self.compute_kernels(all_ema)
Example #13
def joint_linear_scattering(layer, deterministic, c):
    # then standard deep network

    layer.append(layers.Conv2D(T.log(layer[-1] + 0.1), 64, (32, 16)))
    layer.append(layers.BatchNormalization(layer[-1], [0, 2, 3],
                                           deterministic))
    layer.append(layers.Lambda(layer[-1], T.abs))

    N = layer[-1].shape[0]
    features = T.concatenate([
        layer[-1].mean(3).reshape([N, -1]),
        T.log(layer[-4] + 0.1).mean(3).reshape([N, -1])
    ], 1)
    layer.append(layers.Dropout(features, 0.1, deterministic))

    layer.append(layers.Dense(layer[-1], c))
    return layer
Example #14
        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)
            to_return = T.concatenate([lambda_expanded, phi_expanded])

            # jax.lax.cond(to_return.shape == (2,10,10), lambda _: print(f'{idx}, true'), lambda _: print(f'{idx}, false'), operand = None)

            return to_return, to_return
    def compute_q(self, DATA):
        DATAT = T.transpose(DATA)
        xz = DATAT[0]
        # print(xz)
        init = self.su**2 * T.linalg.norm(T.expand_dims(xz, axis = 0), ord = 2, axis = 0) + self.sb**2 + self.sh**2
        # print("init", init)
        # init = self.su * 2 * T.linalg.norm(xz, ord = 2) + self.sb**2 + self.sh**2 #make this a vectorized function 

        def scan_func(prevq, MINIDATAT): # MINIDATAT should be a vector of length N
            # print(MINIDATAT)
            # the trick to this one is to use the original VT
            S = self.alg1_VT(prevq) # -> M is K3 
            # S = prevq
            # newq = self.su * 2 * T.linalg.norm(T.expand_dims(xz, axis = 0), ord = 2, axis = 0) + self.sb**2 + self.sh**2
            newq = self.sw**2 * S + self.su**2 * T.linalg.norm(T.expand_dims(MINIDATAT, axis = 0), ord = 2, axis = 0) + self.sb**2
            # print("newq", newq)
            return newq, newq
        
        last_ema, all_ema = T.scan(scan_func, init = init, sequences = [DATAT[1:]])
        return T.concatenate([T.expand_dims(init, axis = 0), all_ema])
    def reorganize_data(self, DATA, printbool=False):
        TiPrimes, Tis = self.get_diag_indices(printbool=printbool)
        reorganized_data = None
        for diag_idx in range(1, self.dim_num - 1):
            TiP = TiPrimes[diag_idx]
            Ti = Tis[diag_idx]
            dim_len = self.dim_lengths[diag_idx]
            for diag_pos in range(1, int(dim_len)):
                # we should never see 0 here, since those are reserved for boundary conditions
                if printbool:
                    print(
                        f"taking position {TiP + diag_pos} from TiP, {Ti + diag_pos} from Ti"
                    )

                reorganized_data = self.add_or_create(
                    reorganized_data,
                    T.concatenate([
                        T.expand_dims(DATA[:, TiP + diag_pos], axis=0),
                        T.expand_dims(DATA[:, Ti + diag_pos], axis=0)
                    ]))
        return reorganized_data
        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in
            # x = Xph['indexed']
            # X = x*x[:, None]

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)

            to_return = T.concatenate([lambda_expanded, phi_expanded])

            if idx in self.ends_of_calced_diags:
                # self.save_vts[idx] = [S,D]
                return boundary_condition, to_return
            return to_return, to_return
Example #18
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               var_x=1):

    alpha = T.Placeholder((1,), 'float32')
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')

    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')
    
    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')

    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]

    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))

    # create the placeholders
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)

    ############################################################################
    # Compute the output of g(x)
    ############################################################################

    maps = [x]
    xsigns = []
    masks = []
    
    for w, v in zip(Ws[:-1], vs[:-1]):
        
        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])

    xsigns = T.concatenate(xsigns)
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])

    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################

    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)

    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)

    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]

    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]

    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]

    ############################################################################
    # loss (E-step NLL)
    ############################################################################

    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]),
                        axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)

    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)
    
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.mean(v**2) for v in vs[:-1]], 0.) / cov_b
    loss = - 0.5 * (T.log(var_x).sum() + T.log(var_z).sum()\
            + (M2diag / var_z).sum(1).mean() + ((X ** 2 - 2 * xAm1Bm0 + B2m0\
            + Am2ATdiag + 2 * ABm1) / var_x).sum(1).mean())

    mean_loss = - (loss + 0.5 * prior)
    adam = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)

    ############################################################################
    # update of var_x
    ############################################################################

    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean()\
                    * T.ones(Ds[-1])
    update_varz = M2diag.mean() * T.ones(Ds[0])

    ############################################################################
    # update for biases IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):
        
        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None]\
                                         - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            whiten = T.einsum('ndc,nds,n->cs', FQ[i] , FQ[i], m0.mean(0))\
                        + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)

    ############################################################################
    # update for slopes IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX 
    ############################################################################

    update_Ws = {}
    for i in range(len(Ws)):
        
        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i-1], bQs[i-1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i-1], AQs[i-1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i-1], bQs[i-1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size

        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i]*Ds[i+1]*cov_W)
        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]\
                            + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i-1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)

        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1) +\
                               T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))

        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        condition = T.diagonal(whiten)
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W

    ############################################################################
    # create the io functions
    ############################################################################

    params = sj.function(outputs = Ws + vs + [var_x])
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))
    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]])\
                            * selector[i] + vs[i] * (1 - selector[i])
    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]])\
                            * selector[i] + Ws[i] * (1 - selector[i])

    output = {'train':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                  updates=adam.updates),
              'update_var':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                        updates = {var_x: update_varx}),
              'update_vs':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                      updates = update_vs),
              'loss':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                      updates = update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0],
                                       bxs[-1][0], x_inequalities, xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph,
                                    updates=dict(zip(Ws + vs + [var_x],
                                             Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
#              'probed' : sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S' : Ds[0], 'D':  Ds[-1], 'R': R, 'model': 'EM', 'L':len(Ds)-1,
              'kwargs': {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler}}
 
    def sample(n):
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)
    
    output['sample'] = sample

    return output
    # x = DATA[:,0]
    # X = x*x[:, None]
    # n = X.shape[0]

    rntkod = RNTK(dic, DATA, DATAPRIME) #could be flipped 

    start = time.time()
    kernels_ema = rntkod.create_func_for_diag()
    diag_func = symjax.function(DATA, DATAPRIME, outputs=kernels_ema)
    if printbool:
        print("time to create symjax", time.time() - start)

    return diag_func, rntkod
    # return None, rntkod

create_T_list = lambda vals: T.concatenate([T.expand_dims(i, axis = 0) for i in vals])
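A small illustration of the helper above (an assumption about its intended use): given equally shaped tensors, it stacks them along a new leading axis, i.e. it behaves like T.stack.

import symjax.tensor as T

a = T.ones((3, 3))
b = T.zeros((3, 3))
stacked = create_T_list([a, b])   # shape (2, 3, 3), equivalent to T.stack([a, b])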

class RNTK():
    def __init__(self, dic, DATA, DATAPRIME, simple = False):
        if not simple:
            self.dim_1 = dic["n_entradasTiP="]
            self.dim_2 = dic["n_entradasTi="]
            self.dim_num =self.dim_1 + self.dim_2 + 1
            self.DATA = DATA
            self.DATAPRIME = DATAPRIME
            self.N = int(dic["n_patrons1="])
        self.sw = 1.142 #sqrt 2 (1.4) - FIXED
        self.su = 0.5 #[0.1,0.5,1] - SEARCH
        self.sb = 0.1 #[0, 0.2, 0.5] - SEARCH
        self.sh = 1 #[0, 0.5, 1] - SEARCH
        self.sv = 1 #1 - FIXED
 def add_or_create(self, tlist, titem):
     if tlist is None:
         return T.expand_dims(titem, axis=0)
     else:
         return T.concatenate([tlist, T.expand_dims(titem, axis=0)])
Example #21
def create_fns(input,
               in_signs,
               Ds,
               x,
               m0,
               m1,
               m2,
               batch_in_signs,
               alpha=0.1,
               sigma=1,
               sigma_x=1,
               lr=0.0002):

    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    BS = batch_in_signs.shape[0]
    Ws = [
        T.Variable(sj.initializers.glorot((j, i)) * sigma)
        for j, i in zip(Ds[1:], Ds[:-1])
    ]
    bs = [T.Variable(sj.initializers.he((j,)) * sigma) for j in Ds[1:-1]]\
                + [T.Variable(T.zeros((Ds[-1],)))]

    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    batch_A_q = [T.eye(Ds[0]) * T.ones((BS, 1, 1))]
    batch_B_q = [T.zeros((BS, Ds[0]))]

    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]

    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., alpha)
    batch_in_masks = T.where(
        T.concatenate([T.ones((BS, Ds[0])), batch_in_signs], 1) > 0, 1., alpha)

    for w, b in zip(Ws[:-1], bs[:-1]):

        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., alpha))

        maps.append(pre_activation * masks[-1])

    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1], cumulative_units[1:],
                                   Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

        batch_A_q.append(
            T.matmul(w * batch_in_masks[:, None, start:end], batch_A_q[-1]))
        batch_B_q.append((w * batch_in_masks[:, None, start:end]\
                            * batch_B_q[-1][:, None, :]).sum(2) + b)

    batch_B_q = batch_B_q[-1]
    batch_A_q = batch_A_q[-1]

    signs = T.concatenate(signs)

    inequalities = T.hstack(
        [T.concatenate(B_w[1:-1])[:, None],
         T.vstack(A_w[1:-1])]) * signs[:, None]

    inequalities_code = T.hstack(
        [T.concatenate(B_q[1:-1])[:, None],
         T.vstack(A_q[1:-1])]) * in_signs[:, None]

    #### loss
    log_sigma2 = T.Variable(sigma_x)
    sigma2 = T.exp(log_sigma2)

    Am1 = T.einsum('qds,nqs->nqd', batch_A_q, m1)
    Bm0 = T.einsum('qd,nq->nd', batch_B_q, m0)
    B2m0 = T.einsum('nq,qd->n', m0, batch_B_q**2)
    AAm2 = T.einsum('qds,qdu,nqup->nsp', batch_A_q, batch_A_q, m2)

    inner = -(x * (Am1.sum(1) + Bm0)).sum(1) + (Am1 * batch_B_q).sum((1, 2))

    loss_2 = (x**2).sum(1) + B2m0 + T.trace(AAm2, axis1=1, axis2=2).squeeze()

    loss_z = T.trace(m2.sum(1), axis1=1, axis2=2).squeeze()

    cst = 0.5 * (Ds[0] + Ds[-1]) * T.log(2 * np.pi)

    loss = cst + 0.5 * Ds[-1] * log_sigma2 + inner / sigma2\
            + 0.5 * loss_2 / sigma2 + 0.5 * loss_z

    mean_loss = loss.mean()
    adam = sj.optimizers.NesterovMomentum(mean_loss, Ws + bs, lr, 0.9)

    train_f = sj.function(batch_in_signs,
                          x,
                          m0,
                          m1,
                          m2,
                          outputs=mean_loss,
                          updates=adam.updates)
    f = sj.function(input,
                    outputs=[maps[-1], A_w[-1], B_w[-1], inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])
    return f, g, h, all_g, train_f, sigma2
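A hypothetical instantiation of the create_fns above. The placeholder shapes mirror how the tensors are used inside the function (x and the moments m0, m1, m2 are batched over N data points and R sign patterns); the concrete sizes and the random E-step statistics are assumptions for illustration only.

import numpy as np
import symjax as sj
import symjax.tensor as T

Ds, R, N = [2, 8, 3], 4, 16               # assumed widths, sign patterns, batch size
D_signs = int(np.sum(Ds[1:-1]))

input = T.Placeholder((Ds[0],), 'float32')
in_signs = T.Placeholder((D_signs,), 'float32')
x = T.Placeholder((N, Ds[-1]), 'float32')
m0 = T.Placeholder((N, R), 'float32')
m1 = T.Placeholder((N, R, Ds[0]), 'float32')
m2 = T.Placeholder((N, R, Ds[0], Ds[0]), 'float32')
batch_in_signs = T.Placeholder((R, D_signs), 'float32')

f, g, h, all_g, train_f, sigma2 = create_fns(input, in_signs, Ds, x,
                                             m0, m1, m2, batch_in_signs)

# one hypothetical training step with random E-step statistics
loss = train_f(np.sign(np.random.randn(R, D_signs)).astype('float32'),
               np.random.randn(N, Ds[-1]).astype('float32'),
               np.ones((N, R), 'float32') / R,
               np.zeros((N, R, Ds[0]), 'float32'),
               np.tile(np.eye(Ds[0], dtype='float32'), (N, R, 1, 1)))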
Example #22
 def create_network(self, state, action):
     input = nn.relu(nn.layers.Dense(T.concatenate([state, action], 1),
                                     300))
     input = nn.relu(nn.layers.Dense(input, 300))
     input = nn.layers.Dense(input, 1)
     return input
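A hypothetical instantiation of the critic defined above; `agent` stands for whatever object carries create_network, and the batch size and state/action dimensions are assumptions for illustration.

import symjax.tensor as T

state = T.Placeholder((64, 17), 'float32')       # assumed batch of 64 states
action = T.Placeholder((64, 6), 'float32')       # assumed batch of 64 actions
q_values = agent.create_network(state, action)   # symjax tensor of shape (64, 1)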