def test_strided_elementwise_inplace():
    """Compare ``PyCudaHandler.strided_elementwise_inplace`` with NumPy.

    For 2-, 3- and 4-dimensional random shapes (10 draws each, every
    extent in ``[1, 16]``) three float32 arrays are stacked along axis 0
    and the stack is reshaped so that its trailing axis has length 3.
    ``'logistic'`` is applied with index 1 and ``'tanh'`` with index 0 on
    the device; the result is copied back, reshaped to the stacked shape
    and compared with a NumPy reference built from the three arrays.
    """
    # Kept as a function-local import, exactly as in the original test.
    from brainstorm.handlers import PyCudaHandler

    handler = PyCudaHandler()
    rdm = np.random.RandomState(1345)

    def random_float32(shape):
        # randn takes one positional argument per dimension.
        return rdm.randn(*shape).astype(np.float32)

    for dims in range(2, 5):
        for _ in range(10):
            shape = rdm.randint(1, 17, dims)
            a1 = random_float32(shape)
            a2 = random_float32(shape)
            a3 = random_float32(shape)

            stacked = np.vstack([a1, a2, a3])
            stacked_shape = stacked.shape
            # Same data, viewed with a trailing axis of length 3.
            # NOTE(review): presumably the handler's second argument selects
            # a position along this trailing axis -- confirm against the
            # handler implementation.
            strided_shape = (
                (stacked_shape[0] // 3,) + tuple(stacked_shape[1:]) + (3,)
            )
            device_array = handler.create_from_numpy(
                stacked.reshape(strided_shape)
            )

            handler.strided_elementwise_inplace(device_array, 1, 'logistic')
            handler.strided_elementwise_inplace(device_array, 0, 'tanh')
            outputs = handler.get_numpy_copy(device_array).reshape(
                stacked_shape
            )

            # NOTE(review): the reference for 'logistic' uses exp(a2), not
            # exp(-a2) as in the usual 1 / (1 + exp(-x)); kept as in the
            # original test -- verify against the handler's kernel.
            expected = np.vstack([
                np.tanh(a1),
                1. / (1. + np.exp(a2)),
                a3,
            ])

            assert np.allclose(outputs, expected), (
                "strided_elementwise_inplace mismatch for shape "
                "{}".format(tuple(shape))
            )