Exemplo n.º 1
0
def test_compatible_with_other_implementation():
    """Compare ``LSTMOpInstance`` with a plain ``theano.scan`` LSTM.

    The same recurrence is built twice: once through ``LSTMOpInstance`` fed
    with ``T.dot(X, W) + b``, and once through the reference step function
    defined below.  The output sequences and the gradients of their sums
    with respect to ``X``, ``W``, ``V_h`` and ``b`` must agree within
    ``rtol=3e-5``.  The index matrix fed to the op is all ones.
    """
    X = T.ftensor3('X')
    W = T.fmatrix('W')
    V_h = T.fmatrix('V_h')
    b = T.fvector('b')
    c = T.fmatrix('c')  # initial state
    i = T.matrix('i', dtype='int8')
    # Disabled alternative: Y, _, _ = LSTMOp2Instance(V_h, c, b, i, X, W)
    Y, _, _ = LSTMOpInstance(T.dot(X, W) + b, V_h, c, i)
    DX = T.grad(Y.sum(), X)
    DW = T.grad(Y.sum(), W)
    DV_h = T.grad(Y.sum(), V_h)
    Db = T.grad(Y.sum(), b)

    n_T = 5
    n_batch = 4
    n_inp_dim = 3
    n_cells = 8
    X_val = numpy.random.ranf((n_T, n_batch, n_inp_dim)).astype('float32')
    W_val = numpy.random.ranf((n_inp_dim, 4 * n_cells)).astype('float32')
    V_h_val = numpy.random.ranf((n_cells, 4 * n_cells)).astype('float32')
    b_val = numpy.random.ranf((4 * n_cells, )).astype('float32')
    c_val = numpy.random.ranf((n_batch, n_cells)).astype('float32')
    # Initial output of the reference scan (all zeros).
    y0_val = numpy.zeros((n_batch, n_cells), dtype='float32')
    i_val = numpy.ones((n_T, n_batch), dtype='int8')

    def _step(x_t, c_tm1, y_tm1):
        """Reference LSTM step: returns the new cell state and output."""
        z_t = T.dot(x_t, W) + T.dot(y_tm1, V_h) + b
        # Floor division keeps the slice bound an integer on both
        # Python 2 and Python 3 ("/" is true division on Python 3).
        partition = z_t.shape[1] // 4
        ingate = T.nnet.sigmoid(z_t[:, :partition])
        forgetgate = T.nnet.sigmoid(z_t[:, partition:2 * partition])
        outgate = T.nnet.sigmoid(z_t[:, 2 * partition:3 * partition])
        cell_input = T.tanh(z_t[:, 3 * partition:4 * partition])
        c_t = forgetgate * c_tm1 + ingate * cell_input
        y_t = outgate * T.tanh(c_t)
        return c_t, y_t

    [_, Y2], _ = theano.scan(_step,
                             sequences=[X],
                             outputs_info=[c, y0_val])

    DX2 = T.grad(Y2.sum(), X)
    DW2 = T.grad(Y2.sum(), W)
    DV_h2 = T.grad(Y2.sum(), V_h)
    Db2 = T.grad(Y2.sum(), b)

    f = theano.function(
        inputs=[X, W, V_h, c, b, i],
        outputs=[Y, Y2, DX, DX2, DW, DW2, DV_h, DV_h2, Db, Db2])
    (Y_val, Y2_val, DX_val, DX2_val, DW_val, DW2_val, DV_h_val, DV_h2_val,
     Db_val, Db2_val) = f(X_val, W_val, V_h_val, c_val, b_val, i_val)
    vals_fast = [Y_val, DX_val, DW_val, DV_h_val, Db_val]
    vals_fast = [numpy.asarray(A, dtype='float32') for A in vals_fast]
    vals_simple = [Y2_val, DX2_val, DW2_val, DV_h2_val, Db2_val]

    names = ["Y_val", "DX_val", "DW_val", "DV_h_val", "Db_val"]
    # Distinct loop names so the compiled function ``f`` is not rebound.
    for fast, simple, name in zip(vals_fast, vals_simple, names):
        assert numpy.allclose(fast, simple, rtol=3e-5), (name, fast, simple)

    print("success")
Exemplo n.º 2
0
def test_bwd_pass_compatible_with_OpLSTM():
    """Compare gradients of the custom test op with ``LSTMOpInstance``.

    ``LSTMCustomTestOpNoInplaceInstance`` is called with
    ``(Z, c, y0, i, W_re, W_att_in)``.  The reference graph feeds
    ``LSTMOpInstance`` with ``W_re + W_att_in`` as recurrent weights and
    with ``Z`` whose first time step is incremented by
    ``T.dot(y0, W_re + W_att_in)``.  The gradients of the summed outputs
    with respect to ``Z``, ``W_re``, ``c``, ``y0`` and ``W_att_in`` must
    match between the two graphs.
    """
    Z = T.ftensor3('Z')
    W_re = T.fmatrix('W_re')
    W_att_in = T.fmatrix('W_att_in')
    c = T.fmatrix('c')  # initial state
    y0 = T.fmatrix('y0')  # initial activation
    i = T.matrix('i', dtype='int8')
    # Only the first output of each op is used by this test.
    Y, _, _ = LSTMCustomTestOpNoInplaceInstance(Z, c, y0, i, W_re, W_att_in)
    W_re_modified = W_re + W_att_in
    Z_modified = T.inc_subtensor(Z[0], T.dot(y0, W_re_modified))
    Y2, _, _ = LSTMOpInstance(Z_modified, W_re_modified, c, i)

    cost = Y.sum()
    DZ = T.grad(cost, Z)
    DW_re = T.grad(cost, W_re)
    DW_att_in = T.grad(cost, W_att_in)
    Dc = T.grad(cost, c)
    Dy0 = T.grad(cost, y0)
    cost2 = Y2.sum()
    DZ2 = T.grad(cost2, Z)
    DW_re2 = T.grad(cost2, W_re)
    DW_att_in2 = T.grad(cost2, W_att_in)
    Dc2 = T.grad(cost2, c)
    Dy02 = T.grad(cost2, y0)

    f = theano.function(inputs=[Z, c, y0, i, W_re, W_att_in],
                        outputs=[DZ, DW_re, Dc, Dy0, DW_att_in])
    g = theano.function(inputs=[Z, W_re, c, y0, i, W_att_in],
                        outputs=[DZ2, DW_re2, Dc2, Dy02, DW_att_in2])

    n_T = 5
    n_batch = 4
    n_cells = 8
    numpy.random.seed(1234)
    Z_val = numpy.random.ranf((n_T, n_batch, 4 * n_cells)).astype('float32')
    W_re_val = numpy.random.ranf((n_cells, 4 * n_cells)).astype('float32')
    W_att_in_val = numpy.random.ranf((n_cells, 4 * n_cells)).astype('float32')
    c_val = numpy.random.ranf((n_batch, n_cells)).astype('float32')
    y0_val = numpy.random.ranf((n_batch, n_cells)).astype('float32')
    # The literal has one row per batch entry; the transpose yields an
    # array of shape (n_T, n_batch).
    i_val = numpy.array(
        [[1, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 1, 0, 0]],
        dtype='int8').T

    vals = f(Z_val, c_val, y0_val, i_val, W_re_val, W_att_in_val)
    DZ_val, DW_re_val, Dc_val, Dy0_val, DW_att_in_val = [
        numpy.asarray(x) for x in vals
    ]
    vals2 = g(Z_val, W_re_val, c_val, y0_val, i_val, W_att_in_val)
    DZ2_val, DW_re2_val, Dc2_val, Dy02_val, DW_att_in2_val = [
        numpy.asarray(x) for x in vals2
    ]
    assert numpy.allclose(DZ_val, DZ2_val, atol=5e-7, rtol=1e-4)
    assert numpy.allclose(DW_re_val, DW_re2_val, atol=5e-7, rtol=1e-4)
    assert numpy.allclose(Dc_val, Dc2_val)
    assert numpy.allclose(Dy0_val, Dy02_val)
    assert numpy.allclose(DW_att_in_val, DW_att_in2_val, atol=5e-7, rtol=1e-4)
    print("success")
Exemplo n.º 3
0
def test_compatible_with_other_implementation():
  """Compare ``LSTMOpInstance`` with a plain ``theano.scan`` LSTM.

  The same recurrence is built twice: once through ``LSTMOpInstance`` fed
  with ``T.dot(X, W) + b``, and once through the reference step function
  defined below.  The output sequences and the gradients of their sums
  with respect to ``X``, ``W``, ``V_h`` and ``b`` must agree within
  ``rtol=3e-5``.  The index matrix fed to the op is all ones.
  """
  X = T.ftensor3('X')
  W = T.fmatrix('W')
  V_h = T.fmatrix('V_h')
  b = T.fvector('b')
  c = T.fmatrix('c')  # initial state
  i = T.matrix('i', dtype='int8')
  # Disabled alternative: Y, _, _ = LSTMOp2Instance(V_h, c, b, i, X, W)
  Y, _, _ = LSTMOpInstance(T.dot(X, W) + b, V_h, c, i)
  DX = T.grad(Y.sum(), X)
  DW = T.grad(Y.sum(), W)
  DV_h = T.grad(Y.sum(), V_h)
  Db = T.grad(Y.sum(), b)

  n_T = 5
  n_batch = 4
  n_inp_dim = 3
  n_cells = 8
  X_val = numpy.random.ranf((n_T, n_batch, n_inp_dim)).astype('float32')
  W_val = numpy.random.ranf((n_inp_dim, 4 * n_cells)).astype('float32')
  V_h_val = numpy.random.ranf((n_cells, 4 * n_cells)).astype('float32')
  b_val = numpy.random.ranf((4 * n_cells,)).astype('float32')
  c_val = numpy.random.ranf((n_batch, n_cells)).astype('float32')
  # Initial output of the reference scan (all zeros).
  y0_val = numpy.zeros((n_batch, n_cells), dtype='float32')
  i_val = numpy.ones((n_T, n_batch), dtype='int8')

  def _step(x_t, c_tm1, y_tm1):
    """Reference LSTM step: returns the new cell state and output."""
    z_t = T.dot(x_t, W) + T.dot(y_tm1, V_h) + b
    # Floor division keeps the slice bound an integer on both
    # Python 2 and Python 3 ("/" is true division on Python 3).
    partition = z_t.shape[1] // 4
    ingate = T.nnet.sigmoid(z_t[:, :partition])
    forgetgate = T.nnet.sigmoid(z_t[:, partition:2 * partition])
    outgate = T.nnet.sigmoid(z_t[:, 2 * partition:3 * partition])
    cell_input = T.tanh(z_t[:, 3 * partition:4 * partition])
    c_t = forgetgate * c_tm1 + ingate * cell_input
    y_t = outgate * T.tanh(c_t)
    return c_t, y_t

  [_, Y2], _ = theano.scan(_step, sequences=[X],
                           outputs_info=[c, y0_val])

  DX2 = T.grad(Y2.sum(), X)
  DW2 = T.grad(Y2.sum(), W)
  DV_h2 = T.grad(Y2.sum(), V_h)
  Db2 = T.grad(Y2.sum(), b)

  f = theano.function(
    inputs=[X, W, V_h, c, b, i],
    outputs=[Y, Y2, DX, DX2, DW, DW2, DV_h, DV_h2, Db, Db2])
  (Y_val, Y2_val, DX_val, DX2_val, DW_val, DW2_val, DV_h_val, DV_h2_val,
   Db_val, Db2_val) = f(X_val, W_val, V_h_val, c_val, b_val, i_val)
  vals_fast = [Y_val, DX_val, DW_val, DV_h_val, Db_val]
  vals_fast = [numpy.asarray(A, dtype='float32') for A in vals_fast]
  vals_simple = [Y2_val, DX2_val, DW2_val, DV_h2_val, Db2_val]

  names = ["Y_val", "DX_val", "DW_val", "DV_h_val", "Db_val"]
  # Distinct loop names so the compiled function ``f`` is not rebound.
  for fast, simple, name in zip(vals_fast, vals_simple, names):
    assert numpy.allclose(fast, simple, rtol=3e-5), (name, fast, simple)

  print("success")
Exemplo n.º 4
0
def test_bwd_pass_compatible_with_OpLSTM():
  """Gradient parity check: custom test op versus ``LSTMOpInstance``.

  The custom op receives ``(Z, c, y0, i, W_re, W_att_in)``.  The reference
  graph runs ``LSTMOpInstance`` with ``W_re + W_att_in`` as recurrent
  weights on a copy of ``Z`` whose first time step is incremented by
  ``T.dot(y0, W_re + W_att_in)``.  Gradients of the summed outputs with
  respect to ``Z``, ``W_re``, ``c``, ``y0`` and ``W_att_in`` are compared.
  """
  # Symbolic inputs.
  Z = T.ftensor3('Z')
  W_re = T.fmatrix('W_re')
  W_att_in = T.fmatrix('W_att_in')
  c = T.fmatrix('c')    # initial state
  y0 = T.fmatrix('y0')  # initial activation
  i = T.matrix('i', dtype='int8')

  # Graph under test, then the reference graph.
  Y, _H, _d = LSTMCustomTestOpNoInplaceInstance(Z, c, y0, i, W_re, W_att_in)
  W_re_modified = W_re + W_att_in
  Z_modified = T.inc_subtensor(Z[0], T.dot(y0, W_re_modified))
  Y2, _H2, _d2 = LSTMOpInstance(Z_modified, W_re_modified, c, i)

  # Gradients are requested in the order in which they are compared below.
  wrt = [Z, W_re, c, y0, W_att_in]
  cost = Y.sum()
  grads_custom = [T.grad(cost, v) for v in wrt]
  cost2 = Y2.sum()
  grads_ref = [T.grad(cost2, v) for v in wrt]

  f = theano.function(inputs=[Z, c, y0, i, W_re, W_att_in],
                      outputs=grads_custom)
  g = theano.function(inputs=[Z, W_re, c, y0, i, W_att_in],
                      outputs=grads_ref)

  num_steps, batch_size, num_cells = 5, 4, 8

  def _rand32(shape):
    # Uniform samples from numpy's global RNG, cast to float32.
    return numpy.random.ranf(shape).astype('float32')

  numpy.random.seed(1234)
  Z_val = _rand32((num_steps, batch_size, 4 * num_cells))
  W_re_val = _rand32((num_cells, 4 * num_cells))
  W_att_in_val = _rand32((num_cells, 4 * num_cells))
  c_val = _rand32((batch_size, num_cells))
  y0_val = _rand32((batch_size, num_cells))
  # The literal has one row per batch entry; transposing gives an array of
  # shape (num_steps, batch_size).
  i_val = numpy.array([[1, 1, 1, 1, 1],
                       [0, 0, 1, 1, 1],
                       [0, 0, 1, 1, 1],
                       [0, 0, 1, 0, 0]], dtype='int8').T

  custom_vals = [numpy.asarray(x)
                 for x in f(Z_val, c_val, y0_val, i_val, W_re_val, W_att_in_val)]
  ref_vals = [numpy.asarray(x)
              for x in g(Z_val, W_re_val, c_val, y0_val, i_val, W_att_in_val)]

  # Tolerances per gradient: DZ, DW_re, Dc, Dy0, DW_att_in.
  loose = dict(atol=5e-7, rtol=1e-4)
  tolerances = [loose, loose, {}, {}, loose]
  for got, expected, tol in zip(custom_vals, ref_vals, tolerances):
    assert numpy.allclose(got, expected, **tol)
  print("success")
Exemplo n.º 5
0
def test_fwd_pass_compatible_with_OpLSTM():
  """Forward parity check: custom test op versus ``LSTMOpInstance``.

  The custom op receives ``(Z, c, y0, i, W_re, W_att_in)``.  The reference
  graph runs ``LSTMOpInstance`` with ``W_re + W_att_in`` as recurrent
  weights on a copy of ``Z`` whose first time step is incremented by
  ``T.dot(y0, W_re + W_att_in)``.  The first output of both ops must match
  under ``numpy.allclose`` with default tolerances.
  """
  # Symbolic inputs.
  Z = T.ftensor3('Z')
  W_re = T.fmatrix('W_re')
  W_att_in = T.fmatrix('W_att_in')
  c = T.fmatrix('c')    # initial state
  y0 = T.fmatrix('y0')  # initial activation
  i = T.matrix('i', dtype='int8')

  # Graph under test, then the reference graph.
  Y, _H, _d = LSTMCustomTestOpNoInplaceInstance(Z, c, y0, i, W_re, W_att_in)
  W_re_modified = W_re + W_att_in
  Z_modified = T.inc_subtensor(Z[0], T.dot(y0, W_re_modified))
  Y2, _H2, _d2 = LSTMOpInstance(Z_modified, W_re_modified, c, i)

  f = theano.function(inputs=[Z, c, y0, i, W_re, W_att_in], outputs=Y)
  g = theano.function(inputs=[Z, W_re, c, y0, i, W_att_in], outputs=Y2)

  num_steps, batch_size, num_cells = 5, 4, 8

  def _rand32(shape):
    # Uniform samples from numpy's global RNG, cast to float32.
    return numpy.random.ranf(shape).astype('float32')

  numpy.random.seed(1234)
  Z_val = _rand32((num_steps, batch_size, 4 * num_cells))
  W_re_val = _rand32((num_cells, 4 * num_cells))
  W_att_in_val = _rand32((num_cells, 4 * num_cells))
  c_val = _rand32((batch_size, num_cells))
  y0_val = _rand32((batch_size, num_cells))
  # The literal has one row per batch entry; transposing gives an array of
  # shape (num_steps, batch_size).
  i_val = numpy.array([[1, 1, 1, 1, 1],
                       [0, 0, 1, 1, 1],
                       [0, 0, 1, 1, 1],
                       [0, 0, 1, 0, 0]], dtype='int8').T

  custom_out = f(Z_val, c_val, y0_val, i_val, W_re_val, W_att_in_val)
  Y_val = numpy.asarray(custom_out)
  ref_out = g(Z_val, W_re_val, c_val, y0_val, i_val, W_att_in_val)
  Y2_val = numpy.asarray(ref_out)
  assert numpy.allclose(Y_val, Y2_val)
  print("success")
Exemplo n.º 6
0
 def scan(self,
          x,
          z,
          non_sequences,
          i,
          outputs_info,
          W_re,
          W_in,
          b,
          go_backwards=False,
          truncate_gradient=-1):
     """Run ``LSTMOpInstance`` over ``z`` and return two of its outputs.

     ``T.dot(outputs_info[0], W_re)`` is added to one time step of ``z``
     (index ``-1`` when ``go_backwards`` is set, index ``0`` otherwise).
     ``z`` and ``i`` are then sliced along their first axis with a stride
     derived from ``go_backwards`` and passed to ``LSTMOpInstance``
     together with ``W_re`` and ``outputs_info[1]``.

     ``x``, ``non_sequences``, ``W_in``, ``b`` and ``truncate_gradient`` are
     accepted but not used by this implementation.

     Returns a two-element list: the op's first output, and its third
     output with a broadcastable leading axis prepended.
     """
     # Index of the time step that receives the initial recurrent term.
     seeded_index = -1 if go_backwards else 0
     seeded_step = z[seeded_index]
     initial_recurrence = T.dot(outputs_info[0], W_re)
     z = T.inc_subtensor(seeded_step, initial_recurrence)

     # Evaluates to 1 for go_backwards=False and -1 for go_backwards=True.
     stride = -(2 * go_backwards - 1)
     z_ordered = z[::stride]
     i_ordered = i[::stride]

     result = LSTMOpInstance(z_ordered, W_re, outputs_info[1], i_ordered)
     first_output = result[0]
     third_output = result[2].dimshuffle('x', 0, 1)
     return [first_output, third_output]
Exemplo n.º 7
0
def test_compatible_with_other_implementation_and_index_vector():
    """Compare ``LSTMOpInstance`` with a masked ``theano.scan`` LSTM.

    For each of three index matrices (layout: time x batch) the op and the
    reference scan defined below must agree, under ``numpy.allclose`` with
    default tolerances, on the output sequence, on the final state, and on
    the gradients of ``output.sum() + final_state.sum()`` with respect to
    ``X``, ``W``, ``V_h`` and ``b``.
    """
    X = T.ftensor3('X')
    W = T.fmatrix('W')
    V_h = T.fmatrix('V_h')
    b = T.fvector('b')
    c = T.fmatrix('c')  # initial state
    i = T.matrix('i', dtype='int8')
    # Disabled alternative: Z, _, h = LSTMOp2Instance(V_h, c, b, i, X, W)
    Z, _, h = LSTMOpInstance(T.dot(X, W) + b, V_h, c, i)
    obj = Z.sum() + h.sum()
    DX = T.grad(obj, X)
    DW = T.grad(obj, W)
    DV_h = T.grad(obj, V_h)
    Db = T.grad(obj, b)
    X_val_mat0 = 0.1 * numpy.array([[1, 2, 3], [4, 5, 6]], dtype='float32')
    X_val_mat1 = 0.1 * numpy.array([[5, 1, 8], [7, 0, 1]], dtype='float32')
    X_val_mat2 = 0.1 * numpy.array([[2, 1, 1], [-7, 0, -1]], dtype='float32')
    X_val = numpy.zeros((3, 2, 3), dtype='float32')
    X_val[0, :, :] = X_val_mat0
    X_val[1, :, :] = X_val_mat1
    X_val[2, :, :] = X_val_mat2
    # The row count of the literal (12) must be divisible by 4 for the
    # LSTM; note the transpose, which makes the final shape (3, 12).
    W_val = 0.1 * numpy.array(
        [[3, 1, 2], [4, 8, 0], [7, 7, 1], [4, 2, -5], [6, -1, -2], [-4, 8, 0],
         [-7, 2, 1], [4, -2, -5], [6, 5, -2], [-4, 8, -6], [-7, 3, -1],
         [4, 2, -5]],
        dtype='float32').T
    # (for LSTM) the first dimension after the transpose is 1/4 of the second.
    V_h_val = 0.1 * numpy.array(
        [[1, 3, 5], [2, -1, -1], [4, 8, -5], [0, -2, 3], [7, 7, 7], [1, 2, 3],
         [5, 2, 1], [-4, 8, -4], [-3, 7, -7], [2, -2, -3], [-5, 2, 1],
         [-4, -5, -4]],
        dtype='float32').T
    b_val = 0.1 * numpy.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                              dtype='float32')
    c_val = numpy.zeros((2, 3), dtype='float32')
    i_vals = [
        numpy.array([[0, 1], [1, 1], [1, 0]], dtype='int8'),
        numpy.array([[0, 1], [0, 1], [0, 1]], dtype='int8'),
        numpy.ones((3, 2), dtype='int8')
    ]  # layout of index vector: time x batch

    # Vectors of ones used to expand the per-sequence index over all cells.
    o_output = T.as_tensor(numpy.ones((3, ), dtype='float32'))
    o_h = T.as_tensor(numpy.ones((3, ), dtype='float32'))

    def _step(x_t, i_t, c_tm1, y_tm1):
        """Reference LSTM step with index masking.

        ``x_t`` already contains the input projection and bias.  Where the
        index is 0 the previous state is carried over and the output is
        zeroed.
        """
        z_t = x_t + T.dot(y_tm1, V_h)
        # Floor division keeps the slice bound an integer on both
        # Python 2 and Python 3 ("/" is true division on Python 3).
        partition = z_t.shape[1] // 4
        ingate = T.nnet.sigmoid(z_t[:, :partition])
        forgetgate = T.nnet.sigmoid(z_t[:, partition:2 * partition])
        outgate = T.nnet.sigmoid(z_t[:, 2 * partition:3 * partition])
        cell_input = T.tanh(z_t[:, 3 * partition:4 * partition])
        c_t = forgetgate * c_tm1 + ingate * cell_input
        y_t = outgate * T.tanh(c_t)
        i_output = T.outer(i_t, o_output)
        i_h = T.outer(i_t, o_h)
        return c_t * i_h + c_tm1 * (1 - i_h), y_t * i_output

    # The input projection is applied outside the scan; ``c`` is used as the
    # initial value for both the state and the output.
    [state, Z2], _ = theano.scan(_step,
                                 sequences=[T.dot(X, W) + b, i],
                                 outputs_info=[c, c])

    h2 = state[-1]
    obj2 = Z2.sum() + h2.sum()
    DX2 = T.grad(obj2, X)
    DW2 = T.grad(obj2, W)
    DV_h2 = T.grad(obj2, V_h)
    Db2 = T.grad(obj2, b)

    f = theano.function(
        inputs=[X, W, V_h, c, b, i],
        outputs=[Z, Z2, DX, DX2, DW, DW2, DV_h, DV_h2, Db, Db2, h2, h])
    for i_val in i_vals:
        (Z_val, Z2_val, DX_val, DX2_val, DW_val, DW2_val, DV_h_val,
         DV_h2_val, Db_val, Db2_val, h2_val, h_val) = f(
             X_val, W_val, V_h_val, c_val, b_val, i_val)
        vals_fast_fwd = [Z_val, h_val]
        vals_fast_fwd = [
            numpy.asarray(A, dtype='float32') for A in vals_fast_fwd
        ]
        vals_fast_grad = [DX_val, DW_val, DV_h_val, Db_val]
        vals_fast_grad = [
            numpy.asarray(A, dtype='float32') for A in vals_fast_grad
        ]
        vals_simple_fwd = [Z2_val, h2_val]
        vals_simple_grad = [DX2_val, DW2_val, DV_h2_val, Db2_val]
        for fa, sl in zip(vals_fast_fwd, vals_simple_fwd):
            assert numpy.allclose(fa, sl)
        for fa, sl in zip(vals_fast_grad, vals_simple_grad):
            assert numpy.allclose(fa, sl)
    print("success")
Exemplo n.º 8
0
def test_compatible_with_other_implementation_and_index_vector():
  """Compare ``LSTMOpInstance`` with a masked ``theano.scan`` LSTM.

  For each of three index matrices (layout: time x batch) the op and the
  reference scan defined below must agree, under ``numpy.allclose`` with
  default tolerances, on the output sequence, on the final state, and on
  the gradients of ``output.sum() + final_state.sum()`` with respect to
  ``X``, ``W``, ``V_h`` and ``b``.
  """
  X = T.ftensor3('X')
  W = T.fmatrix('W')
  V_h = T.fmatrix('V_h')
  b = T.fvector('b')
  c = T.fmatrix('c')  # initial state
  i = T.matrix('i', dtype='int8')
  # Disabled alternative: Z, _, h = LSTMOp2Instance(V_h, c, b, i, X, W)
  Z, _, h = LSTMOpInstance(T.dot(X, W) + b, V_h, c, i)
  obj = Z.sum() + h.sum()
  DX = T.grad(obj, X)
  DW = T.grad(obj, W)
  DV_h = T.grad(obj, V_h)
  Db = T.grad(obj, b)
  X_val_mat0 = 0.1 * numpy.array([[1, 2, 3], [4, 5, 6]], dtype='float32')
  X_val_mat1 = 0.1 * numpy.array([[5, 1, 8], [7, 0, 1]], dtype='float32')
  X_val_mat2 = 0.1 * numpy.array([[2, 1, 1], [-7, 0, -1]], dtype='float32')
  X_val = numpy.zeros((3, 2, 3), dtype='float32')
  X_val[0, :, :] = X_val_mat0
  X_val[1, :, :] = X_val_mat1
  X_val[2, :, :] = X_val_mat2
  # The row count of the literal (12) must be divisible by 4 for the LSTM;
  # note the transpose, which makes the final shape (3, 12).
  W_val = 0.1 * numpy.array([[3, 1, 2], [4, 8, 0], [7, 7, 1], [4, 2, -5],
                             [6, -1, -2], [-4, 8, 0], [-7, 2, 1], [4, -2, -5],
                             [6, 5, -2], [-4, 8, -6], [-7, 3, -1], [4, 2, -5]],
                            dtype='float32').T
  # (for LSTM) the first dimension after the transpose is 1/4 of the second.
  V_h_val = 0.1 * numpy.array([[1, 3, 5], [2, -1, -1], [4, 8, -5], [0, -2, 3],
                               [7, 7, 7], [1, 2, 3], [5, 2, 1], [-4, 8, -4],
                               [-3, 7, -7], [2, -2, -3], [-5, 2, 1], [-4, -5, -4]],
                              dtype='float32').T
  b_val = 0.1 * numpy.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype='float32')
  c_val = numpy.zeros((2, 3), dtype='float32')
  i_vals = [numpy.array([[0, 1], [1, 1], [1, 0]], dtype='int8'),
            numpy.array([[0, 1], [0, 1], [0, 1]], dtype='int8'),
            numpy.ones((3, 2), dtype='int8')]  # layout of index vector: time x batch

  # Vectors of ones used to expand the per-sequence index over all cells.
  o_output = T.as_tensor(numpy.ones((3,), dtype='float32'))
  o_h = T.as_tensor(numpy.ones((3,), dtype='float32'))

  def _step(x_t, i_t, c_tm1, y_tm1):
    """Reference LSTM step with index masking.

    ``x_t`` already contains the input projection and bias.  Where the
    index is 0 the previous state is carried over and the output is zeroed.
    """
    z_t = x_t + T.dot(y_tm1, V_h)
    # Floor division keeps the slice bound an integer on both
    # Python 2 and Python 3 ("/" is true division on Python 3).
    partition = z_t.shape[1] // 4
    ingate = T.nnet.sigmoid(z_t[:, :partition])
    forgetgate = T.nnet.sigmoid(z_t[:, partition:2 * partition])
    outgate = T.nnet.sigmoid(z_t[:, 2 * partition:3 * partition])
    cell_input = T.tanh(z_t[:, 3 * partition:4 * partition])
    c_t = forgetgate * c_tm1 + ingate * cell_input
    y_t = outgate * T.tanh(c_t)
    i_output = T.outer(i_t, o_output)
    i_h = T.outer(i_t, o_h)
    return c_t * i_h + c_tm1 * (1 - i_h), y_t * i_output

  # The input projection is applied outside the scan; ``c`` is used as the
  # initial value for both the state and the output.
  [state, Z2], _ = theano.scan(_step, sequences=[T.dot(X, W) + b, i],
                               outputs_info=[c, c])

  h2 = state[-1]
  obj2 = Z2.sum() + h2.sum()
  DX2 = T.grad(obj2, X)
  DW2 = T.grad(obj2, W)
  DV_h2 = T.grad(obj2, V_h)
  Db2 = T.grad(obj2, b)

  f = theano.function(
    inputs=[X, W, V_h, c, b, i],
    outputs=[Z, Z2, DX, DX2, DW, DW2, DV_h, DV_h2, Db, Db2, h2, h])
  for i_val in i_vals:
    (Z_val, Z2_val, DX_val, DX2_val, DW_val, DW2_val, DV_h_val, DV_h2_val,
     Db_val, Db2_val, h2_val, h_val) = f(X_val, W_val, V_h_val, c_val, b_val, i_val)
    vals_fast_fwd = [Z_val, h_val]
    vals_fast_fwd = [numpy.asarray(A, dtype='float32') for A in vals_fast_fwd]
    vals_fast_grad = [DX_val, DW_val, DV_h_val, Db_val]
    vals_fast_grad = [numpy.asarray(A, dtype='float32') for A in vals_fast_grad]
    vals_simple_fwd = [Z2_val, h2_val]
    vals_simple_grad = [DX2_val, DW2_val, DV_h2_val, Db2_val]
    for fa, sl in zip(vals_fast_fwd, vals_simple_fwd):
      assert numpy.allclose(fa, sl)
    for fa, sl in zip(vals_fast_grad, vals_simple_grad):
      assert numpy.allclose(fa, sl)
  print("success")