Code Example #1
File: layers.py Project: SymJAX/SymJAX
    def __init__(
        self,
        sequence,
        init_h,
        units,
        W=initializers.glorot_uniform,
        H=initializers.orthogonal,
        b=T.zeros,
        trainable_W=True,
        trainable_H=True,
        trainable_b=True,
        activation=nn.sigmoid,
        only_last=False,
    ):

        W = create_variable("W",
                            W, (sequence.shape[2], units),
                            trainable=trainable_W)
        H = create_variable("H", H, (units, units), trainable=trainable_H)
        b = create_variable("b", b, (units, ), trainable=trainable_b)

        last, output = T.scan(
            lambda h, x, W, H, b: RNN.gate(h, x, W, H, b, activation),
            init=init_h,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[W, H, b],
        )
        if only_last:
            return last
        else:
            return output.transpose((1, 0, 2))
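For context, a minimal hypothetical sketch of the per-step update a gate like RNN.gate is expected to compute (illustrative, not the SymJAX source): the scan carry is the hidden state h, and the same tensor is emitted as the per-step output.

import symjax.tensor as T

# Hypothetical sketch: the standard simple-RNN update, returning the new
# hidden state as both the scan carry and the per-step output.
def simple_rnn_gate(h, x, W, H, b, activation):
    new_h = activation(T.dot(x, W) + T.dot(h, H) + b)
    return new_h, new_h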
Code Example #2
    def forward(
        self,
        sequence,
        init_h,
        units,
        W=initializers.he,
        H=initializers.he,
        b=T.zeros,
        trainable_W=True,
        trainable_H=True,
        trainable_b=True,
        activation=nn.sigmoid,
        only_last=False,
    ):

        self.create_variable("W",
                             W, (sequence.shape[2], units),
                             trainable=trainable_W)
        self.create_variable("H", H, (units, units), trainable=trainable_H)
        self.create_variable("b", b, (units), trainable=trainable_b)

        last, output = T.scan(
            lambda h, x, W, H, b: self.gate(h, x, W, H, b, activation),
            init=init_h,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[self.W, self.H, self.b],
        )
        if only_last:
            return last
        else:
            return output.transpose((1, 0, 2))
Code Example #3
    def compute_kernels(self, final_ema):

        diag_ends = self.get_ends_of_diags(final_ema)

        S_init, D_init = self.VT(diag_ends[0][0])
        init_Kappa = self.sv**2 * S_init
        init_Theta = init_Kappa + self.sv**2 * diag_ends[0][1] * D_init
        init_list = T.concatenate([
            T.expand_dims(init_Kappa, axis=0),
            T.expand_dims(init_Theta, axis=0)
        ])

        def map_test(gp_rntk_sum, gp_rntk):
            S, D = self.VT(gp_rntk[0])
            ret1 = self.sv**2 * S
            ret2 = ret1 + self.sv**2 * gp_rntk[1] * D
            gp_rntk_sum = T.index_add(gp_rntk_sum, 0, ret1)
            gp_rntk_sum = T.index_add(gp_rntk_sum, 1, ret2)

            return gp_rntk_sum, gp_rntk_sum

        final_K_T, inter_results = T.scan(map_test,
                                          init=init_list,
                                          sequences=[diag_ends[1:]])
        return final_K_T
Code Example #4
    def RNTK_function(self):
        print(f"N, {self.N}, length, {self.length}")
        DATA = T.Placeholder((self.N, self.length), 'float32')
        RNTK, GP = self.RNTK_first(DATA[:, 0])
        v, _ = T.scan(lambda a, b: self.RNTK_middle(a, b),
                      sequences=[T.transpose(DATA[:, 1:])],
                      init=T.stack([RNTK, GP]))
        RNTK_last, RNTK_avg = self.RNTK_output(v)
        f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
        return f  # return the compiled function; f was previously built but unused
Code Example #5
def RNTK_function(N, length, param):
    DATA = T.Placeholder((N, length), 'float32')
    RNTK, GP = RNTK_first(DATA[:, 0], param['sigmaw'], param['sigmau'],
                          param['sigmab'], param['sigmah'], param['L'],
                          param['Lf'], param['sigmav'])
    v, _ = T.scan(lambda a, b: RNTK_middle(a, b, param['sigmaw'],
                                           param['sigmau'], param['sigmab'],
                                           param['L'], param['Lf'],
                                           param['sigmav']),
                  sequences=[T.transpose(DATA[:, 1:])],
                  init=T.stack([RNTK, GP]))
    RNTK_last, RNTK_avg = RNTK_output(v, param['sigmav'], param['L'],
                                      param['Lf'], length)
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return f
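A hypothetical usage sketch for the compiled function above (the parameter names follow the example; the values and shapes are illustrative):

import numpy as np

# Hypothetical usage: compile once, then evaluate the RNTK outputs on a
# batch of N sequences of the given length (matching the placeholder shape).
param = {'sigmaw': 1.0, 'sigmau': 1.0, 'sigmab': 0.1, 'sigmah': 0.5,
         'sigmav': 1.0, 'L': 1, 'Lf': 0}
f = RNTK_function(N=4, length=16, param=param)
rntk_last, rntk_avg = f(np.random.randn(4, 16).astype('float32'))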
Code Example #6
File: RNTK_NEW.py Project: michaelsprintson/RNTK_UCI
    def create_func_for_diag(self,
                             dim1idx,
                             dim2idx,
                             function=False,
                             jmode=False):
        diag = self.make_inputs(dim1idx, dim2idx, jmode=jmode)
        # print('test')

        ## prev_vals - (2, n, n) - previous phi and lambda values
        ## idx - where we are on the diagonal
        bc = self.sh**2 * self.sw**2 * T.eye(
            self.n, self.n) + (self.su**2) * self.X + self.sb**2
        single_boundary_condition = T.expand_dims(bc, axis=0)
        # single_boundary_condition = T.expand_dims(T.Variable((bc), "float32", "boundary_condition"), axis = 0)
        boundary_condition = T.concatenate(
            [single_boundary_condition,
             single_boundary_condition])  #one for phi and lambda

        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)
            to_return = T.concatenate([lambda_expanded, phi_expanded])

            # jax.lax.cond(to_return.shape == (2,10,10), lambda _: print(f'{idx}, true'), lambda _: print(f'{idx}, false'), operand = None)

            return to_return, to_return

        last_ema, all_ema = T.scan(fn,
                                   init=boundary_condition,
                                   sequences=[diag],
                                   non_sequences=[self.X])

        expanded_ema = T.concatenate(
            [T.expand_dims(boundary_condition, axis=0), all_ema])
        print(expanded_ema)
        if function:
            f = symjax.function(diag, outputs=expanded_ema)
            return f
        else:
            return expanded_ema
Code Example #7
    def create_func_for_diag(self):

        # NEW_DATA = self.reorganize_data()

        ## change - bc should be a function that takes a passed in X?

        bc = self.sh**2 * self.sw**2 * T.eye(self.n, self.n) + (
            self.su**2) * self.X + self.sb**2  ## took out X ||
        single_boundary_condition = T.expand_dims(bc, axis=0)
        # single_boundary_condition = T.expand_dims(T.Variable((bc), "float32", "boundary_condition"), axis = 0)
        boundary_condition = T.concatenate(
            [single_boundary_condition, single_boundary_condition])
        self.boundary_condition = boundary_condition

        # self.save_vts = {}

        ## prev_vals - (2, n, n) - previous phi and lambda values
        ## idx - where we are on the diagonal
        def fn(prev_vals, idx, Xph):

            ## change - xph must now index the dataset instead of being passed in
            # x = Xph['indexed']
            # X = x*x[:, None]

            # tiprime_iter = d1idx + idx
            # ti_iter = d2idx + idx
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            ## not boundary condition
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2  ## took out an X
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)

            to_return = T.concatenate([lambda_expanded, phi_expanded])

            if idx in self.ends_of_calced_diags:
                # self.save_vts[idx] = [S,D]
                return boundary_condition, to_return
            return to_return, to_return

        last_ema, all_ema = T.scan(
            fn,
            init=boundary_condition,
            sequences=[jnp.arange(0,
                                  sum(self.dim_lengths) - self.dim_num)],
            non_sequences=[self.X])
        # if fbool:

        #     return all_ema, f
        return self.compute_kernels(all_ema)
Code Example #8
    def compute_q(self, DATA):
        DATAT = T.transpose(DATA)
        xz = DATAT[0]
        # print(xz)
        init = self.su ** 2 * T.linalg.norm(T.expand_dims(xz, axis=0), ord=2, axis=0) + self.sb ** 2 + self.sh ** 2
        # print("init", init)
        # init = self.su * 2 * T.linalg.norm(xz, ord = 2) + self.sb**2 + self.sh**2 #make this a vectorized function 

        def scan_func(prevq, MINIDATAT):  # MINIDATAT should be a vector of length N
            # the trick to this one is to use the original VT
            S = self.alg1_VT(prevq)  # -> M is K3
            newq = self.sw ** 2 * S + self.su ** 2 * T.linalg.norm(
                T.expand_dims(MINIDATAT, axis=0), ord=2, axis=0) + self.sb ** 2
            return newq, newq
        
        last_ema, all_ema = T.scan(scan_func, init = init, sequences = [DATAT[1:]])
        return T.concatenate([T.expand_dims(init, axis = 0), all_ema])
Code Example #9
File: plot_loops.py Project: SymJAX/SymJAX
# this function should output the new value of the carry as well as an
# additional output; in our case, the carry (EMA) is also what we want to
# output at each time step


def fn(at, xt, alpha):
    # the function's first input is the carry, followed by the (ordered)
    # values from sequences and non_sequences, similar to Theano
    EMA = at * alpha + (1 - alpha) * xt
    return EMA, EMA


# the scan function returns the last carry (first output) as well as the
# carry stacked over all time steps; we also need to provide an init.
last_ema, all_ema = T.scan(
    fn, init=signal[0], sequences=[signal[1:]], non_sequences=[alpha]
)


f = symjax.function(signal, alpha, outputs=all_ema)

# generate a signal
x = np.cos(np.linspace(-3, 3, 512)) + np.random.randn(512) * 0.2

fig, ax = plt.subplots(3, 1, figsize=(3, 9))

for k, alpha in enumerate([0.1, 0.5, 0.9]):
    ax[k].plot(x, c="b")
    ax[k].plot(f(x, alpha), c="r")
    ax[k].set_title("EMA: {}".format(alpha))
    ax[k].set_xticks([])
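Since SymJAX lowers to JAX, the same recurrence can be written directly with jax.lax.scan; a minimal pure-JAX sketch (illustrative, not part of the example above) that reproduces the EMA values:

import numpy as np
from jax import lax

# lax.scan also takes a step function (carry, x) -> (new_carry, output),
# an initial carry, and the scanned sequence; alpha is closed over here
# instead of being passed through non_sequences.
def ema_scan(signal, alpha):
    def step(carry, x):
        ema = carry * alpha + (1 - alpha) * x
        return ema, ema
    _, all_ema = lax.scan(step, signal[0], signal[1:])
    return all_ema

x = np.cos(np.linspace(-3, 3, 512)) + np.random.randn(512) * 0.2
out = ema_scan(x, 0.5)  # matches f(x, 0.5) from the example above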
Code Example #10
File: layers.py Project: SymJAX/SymJAX
    def __init__(
        self,
        sequence,
        init_h,
        units,
        Wf=initializers.glorot_uniform,
        Uf=initializers.orthogonal,
        bf=T.zeros,
        Wi=initializers.glorot_uniform,
        Ui=initializers.orthogonal,
        bi=T.zeros,
        Wo=initializers.glorot_uniform,
        Uo=initializers.orthogonal,
        bo=T.zeros,
        Wc=initializers.glorot_uniform,
        Uc=initializers.orthogonal,
        bc=T.zeros,
        trainable_Wf=True,
        trainable_Uf=True,
        trainable_bf=True,
        trainable_Wi=True,
        trainable_Ui=True,
        trainable_bi=True,
        trainable_Wo=True,
        trainable_Uo=True,
        trainable_bo=True,
        trainable_Wc=True,
        trainable_Uc=True,
        trainable_bc=True,
        activation_g=nn.sigmoid,
        activation_c=T.tanh,
        activation_h=T.tanh,
        only_last=False,
        gate="minimal",
    ):

        self.create_variable("Wf",
                             Wf, (sequence.shape[2], units),
                             trainable=trainable_Wf)
        self.create_variable("Uf", Uf, (units, units), trainable=trainable_Uf)
        self.create_variable("bf", bf, (units, ), trainable=trainable_bf)

        self.create_variable("Wi",
                             Wi, (sequence.shape[2], units),
                             trainable=trainable_Wi)
        self.create_variable("Ui", Ui, (units, units), trainable=trainable_Ui)
        self.create_variable("bi", bi, (units, ), trainable=trainable_bi)

        self.create_variable("Wo",
                             Wo, (sequence.shape[2], units),
                             trainable=trainable_Wo)
        self.create_variable("Uo", Uo, (units, units), trainable=trainable_Uo)
        self.create_variable("bo", bo, (units, ), trainable=trainable_bo)

        self.create_variable("Wc",
                             Wc, (sequence.shape[2], units),
                             trainable=trainable_Wc)
        self.create_variable("Uc", Uc, (units, units), trainable=trainable_Uc)
        self.create_variable("bc", bc, (units, ), trainable=trainable_bc)

        def fn(*args):
            return self.gate(*args, activation_g, activation_c, activation_h)

        init = T.stack((init_h, T.zeros(init_h.shape, init_h.dtype)))
        last, output = T.scan(
            fn,
            init=init,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[
                self.Wf,
                self.Uf,
                self.bf,
                self.Wi,
                self.Ui,
                self.bi,
                self.Wo,
                self.Uo,
                self.bo,
                self.Wc,
                self.Uc,
                self.bc,
            ],
        )

        if only_last:
            return last
        else:
            return output.transpose((1, 0, 2))
Code Example #11
File: layers.py Project: SymJAX/SymJAX
    def __init__(
        self,
        sequence,
        init_h,
        units,
        Wh=initializers.glorot_uniform,
        Uh=initializers.orthogonal,
        bh=T.zeros,
        Wz=initializers.glorot_uniform,
        Uz=initializers.orthogonal,
        bz=T.zeros,
        Wr=initializers.glorot_uniform,
        Ur=initializers.orthogonal,
        br=T.zeros,
        trainable_Wh=True,
        trainable_Uh=True,
        trainable_bh=True,
        trainable_Wz=True,
        trainable_Uz=True,
        trainable_bz=True,
        trainable_Wr=True,
        trainable_Ur=True,
        trainable_br=True,
        activation=nn.sigmoid,
        phi=T.tanh,
        only_last=False,
        gate="minimal",
    ):

        Wh = create_variable("Wh",
                             Wh, (sequence.shape[2], units),
                             trainable=trainable_Wh)
        Uh = create_variable("Uh", Uh, (units, units), trainable=trainable_Uh)
        bh = create_variable("bh", bh, (units, ), trainable=trainable_bh)

        Wz = create_variable("Wz",
                             Wz, (sequence.shape[2], units),
                             trainable=trainable_Wz)
        Uz = create_variable("Uz", Uz, (units, units), trainable=trainable_Uz)
        bz = create_variable("bz", bz, (units, ), trainable=trainable_bz)

        if gate == "full":
            Wr = create_variable("Wr",
                                 Wr, (sequence.shape[2], units),
                                 trainable=trainable_Wr)
            Ur = create_variable("Ur",
                                 Ur, (units, units),
                                 trainable=trainable_Ur)
            br = create_variable("br", br, (units, ), trainable=trainable_br)

        if gate == "minimal":

            def fn(*args):
                return GRU.minimal_gate(*args, activation, phi)

            last, output = T.scan(
                fn,
                init=init_h,
                sequences=[sequence.transpose((1, 0, 2))],
                non_sequences=[Wh, Uh, bh, Wz, Uz, bz],
            )

        elif gate == "full":

            def fn(*args):
                return GRU.full_gate(*args, activation, phi)

            last, output = T.scan(
                fn,
                init=init_h,
                sequences=[sequence.transpose((1, 0, 2))],
                non_sequences=[Wh, Uh, bh, Wz, Uz, bz, Wr, Ur, br],
            )

        if only_last:
            return last
        else:
            return output.transpose((1, 0, 2))
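For reference, a plausible sketch of the "minimal" gated step that GRU.minimal_gate applies (hypothetical, following the minimal gated unit formulation; the actual SymJAX implementation may differ):

import symjax.tensor as T

# Hypothetical sketch of a minimal gated-unit step: a single update gate z
# blends the previous state with a candidate state; the new state is
# returned as both the scan carry and the per-step output.
def minimal_gate_sketch(h, x, Wh, Uh, bh, Wz, Uz, bz, activation, phi):
    z = activation(T.dot(x, Wz) + T.dot(h, Uz) + bz)     # update gate
    h_tilde = phi(T.dot(x, Wh) + T.dot(z * h, Uh) + bh)  # candidate state
    new_h = (1 - z) * h + z * h_tilde                    # gated blend
    return new_h, new_h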
Code Example #12
File: control_flow.py Project: ml-lab/SymJAX
import numpy as np
import sys
sys.path.insert(0, "../")

import symjax
import symjax.tensor as T

# map
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan
xx = T.ones(10) * 2
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
#a = T.scan(lambda c, x: (c*x,c*x), T.ones(1), xx)
#a = T.scan(lambda c, x: (T.square(c),c[0]), uu, xx)
#g = symjax.gradients(a[1][-1],xx)
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())
Code Example #13
    def forward(
        self,
        sequence,
        init_h,
        units,
        Wh=initializers.he,
        Uh=initializers.he,
        bh=T.zeros,
        Wz=initializers.he,
        Uz=initializers.he,
        bz=T.zeros,
        Wr=initializers.he,
        Ur=initializers.he,
        br=T.zeros,
        trainable_Wh=True,
        trainable_Uh=True,
        trainable_bh=True,
        trainable_Wz=True,
        trainable_Uz=True,
        trainable_bz=True,
        trainable_Wr=True,
        trainable_Ur=True,
        trainable_br=True,
        activation=nn.sigmoid,
        phi=T.tanh,
        only_last=False,
        gate="minimal",
    ):

        self.create_variable("Wh",
                             Wh, (sequence.shape[2], units),
                             trainable=trainable_Wh)
        self.create_variable("Uh", Uh, (units, units), trainable=trainable_Uh)
        self.create_variable("bh", bh, (units), trainable=trainable_bh)

        self.create_variable("Wz",
                             Wz, (sequence.shape[2], units),
                             trainable=trainable_Wz)
        self.create_variable("Uz", Uz, (units, units), trainable=trainable_Uz)
        self.create_variable("bz", bz, (units), trainable=trainable_bz)

        if gate == "full":
            self.create_variable("Wr",
                                 Wr, (sequence.shape[2], units),
                                 trainable=trainable_Wr)
            self.create_variable("Ur",
                                 Ur, (units, units),
                                 trainable=trainable_Ur)
            self.create_variable("br", br, (units), trainable=trainable_br)

        if gate == "minimal":

            def fn(h, x, Wh, Uh, bh, Wz, Uz, bz):
                return self.minimal_gate(h, x, Wh, Uh, bh, Wz, Uz, bz,
                                         activation, phi)

        elif gate == "full":

            def fn(h, x, Wh, Uh, bh, Wz, Uz, bz):
                return self.full_gate(h, x, Wh, Uh, bh, Wz, Uz, bz,
                                      self.Wr, self.Ur, self.br,
                                      activation, phi)

        last, output = T.scan(
            fn,
            init=init_h,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[self.Wh, self.Uh, self.bh,
                           self.Wz, self.Uz, self.bz],
        )
        if only_last:
            return last
        else:
            return output.transpose((1, 0, 2))
Code Example #14
    def create_func_for_diag(self):

        NEW_DATA = self.reorganize_data()
        NEW_DATA_ATTACHED = jnp.array(list(zip(NEW_DATA[:-1], NEW_DATA[1:])))
        # print(NEW_DATA_ATTACHED)

        x = self.DATA[:,0]
        X = x*x[:, None]

        boundary_condition = self.make_boundary_condition(X)
        
        # #lets create the inital kernels - should be starting with top right
        temp_K, temp_theta = self.compute_kernels(
            self.qt, self.qtprime, 0, self.dim_1,
            boundary_condition, boundary_condition,
            T.empty((self.N, self.N)), T.empty((self.N, self.N)))
        init_K, init_theta = self.compute_kernels(
            self.qt, self.qtprime, 0, self.dim_1 - 1,
            boundary_condition, boundary_condition, temp_K, temp_theta)

        initial_conditions = create_T_list([boundary_condition, boundary_condition, init_K, init_theta])
        # initial_conditions = create_T_list([T.empty((self.N, self.N)), T.empty((self.N, self.N)), T.empty((self.N, self.N)), T.empty((self.N, self.N)) + 2])
        
        ## prev_vals - (4,self.N,self.N) - previous phi, lambda, and the two kernel values
        ## idx - where we are on the diagonal
        def fn(prev_vals, idx, data_idxs, DATAPH, DATAPRIMEPH, qtph, qtprimeph):

            xTP = DATAPRIMEPH[data_idxs[0][0]] #N1
            xT = DATAPH[data_idxs[0][1]] #N2
            xINNER = T.inner(xT, xTP) #N1 x N2

            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            prev_K = prev_vals[2]
            prev_theta = prev_vals[3]
            ## not boundary condition
            # print(qtph[data_idxs[0][1] - 1])
            # print(qtprimeph[data_idxs[0][0] - 1])
            S, D = self.alg2_VT(qtph[data_idxs[0][1] - 1], qtprimeph[data_idxs[0][0] - 1] ,prev_lambda)
            new_lambda = self.sw ** 2 * S + self.su ** 2 * xINNER + self.sb ** 2 ## took out an X
            new_phi = new_lambda + self.sw ** 2 * prev_phi * D
            # new_phi = prev_phi
            # new_lambda = prev_lambda

            #compute kernels
            S_kernel, D_kernel = self.alg2_VT(qtph[data_idxs[0][1]], qtprimeph[data_idxs[0][0]], new_lambda)
            new_K = prev_K + self.sv**2 * S_kernel  # get current lambda, get current qtph and qtprimeph
            new_theta = prev_theta + self.sv**2 * S_kernel + self.sv**2 * D_kernel * new_phi  # TODO
            # ret_K = prev_K
            # ret_theta = prev_theta

            equal_check = lambda e: T.equal(idx, e)

            equal_result = sum(np.vectorize(equal_check)(self.ends_of_calced_diags)) > 0

            def true_f(k,t, qp, gph, dataph, dataprimeph, di): 
                xTP_NEXT = dataprimeph[di[1][0]]
                xT_NEXT = dataph[di[1][1]]
                xINNER_NEXT = T.inner(xT_NEXT, xTP_NEXT)
                new_bc = self.make_boundary_condition(xINNER_NEXT)
                ret_lambda = ret_phi = new_bc 

                S_bc_kernel, D_bc_kernel = self.alg2_VT(qp[di[1][1]], gph[di[1][0]], ret_lambda)
                ret_K = k + self.sv**2 * S_bc_kernel  # get current lambda, get current qtph and qtprimeph
                ret_theta = t + self.sv**2 * S_bc_kernel + self.sv**2 * D_bc_kernel * ret_phi  # TODO

                return ret_lambda, ret_phi, ret_K, ret_theta
            false_f = lambda l,p,k,t: (l,p,k,t)

            ret_lambda, ret_phi, ret_K, ret_theta = T.cond(
                equal_result, true_f, false_f,
                [new_K, new_theta, qtph, qtprimeph, DATAPH, DATAPRIMEPH, data_idxs],
                [new_lambda, new_phi, new_K, new_theta])
            
            to_carry = create_T_list([ret_lambda, ret_phi, ret_K, ret_theta])
            # print('got poast second create list')
            
            return to_carry, np.array(())
        
        carry_ema, _ = T.scan(
            fn,
            init=initial_conditions,
            sequences=[jnp.arange(0, sum(self.dim_lengths) - self.dim_num),
                       NEW_DATA_ATTACHED],
            non_sequences=[T.transpose(self.DATA), T.transpose(self.DATAPRIME),
                           self.qt, self.qtprime])
        
        return carry_ema[2:4] ## so here, the output will be the added up kernels except for the boundary conditions
Code Example #15
File: scan.py Project: ml-lab/SymJAX
import sys
sys.path.insert(0, "../")
import symjax as sj
import symjax.tensor as T
import numpy as np

__author__ = "Randall Balestriero"

# example of cumulative sum


def func(carry, x):
    return carry + 1, 0


output, _ = T.scan(func, T.zeros(1), T.ones(10), length=10)

f = sj.function(outputs=output)
print(f())
# [10.]

# example of simple RNN

w = T.Placeholder((3, 10), 'float32')
h = T.random.randn((3, 3))
b = T.random.randn((3, ))
t_steps = 100
X = T.random.randn((t_steps, 10))


def rnn_cell(carry, x, w):