Beispiel #1
0
    def test_ops(self):
        """Check the shapes reported by ``K.get_shape`` for basic backend ops.

        Each assertion compares the statically reported shape of an
        expression either with a hard-coded expected tuple or with the shape
        of the array obtained by actually evaluating the expression through
        ``K.eval``.
        """
        x = K.variable(np.random.rand(8, 12))
        y = K.variable(np.random.rand(12, 25))
        z = K.placeholder((25, 18, 13))
        w = K.placeholder((18, 18))

        # ====== dot ====== #
        t = K.dot(x, y)
        self.assertEqual(K.get_shape(t), (8, 25))
        self.assertEqual(K.get_shape(t), K.eval(t).shape)
        # z is a placeholder, so from here on only the reported shape of t
        # is checked; t is not evaluated again.
        t = K.dot(t, K.dimshuffle(z, (1, 0, 2)))
        self.assertEqual(K.get_shape(t), (8, 18, 13))

        # ====== transpose ====== #
        self.assertEqual(K.get_shape(K.transpose(z)), (13, 18, 25))
        self.assertEqual(K.get_shape(K.transpose(t, axes=(2, 0, 1))),
                         (13, 8, 18))

        # ====== eye ====== #
        self.assertEqual(K.get_shape(K.eye(5)), K.eval(K.eye(5)).shape)
        # ====== diag ====== #
        self.assertEqual(K.get_shape(K.diag(w)), (18, ))
        # NOTE(review): the evaluated diag check below is disabled in the
        # original; presumably for the same reason as the diag TODO in
        # test_basic_ops_value (pygpu lacking diag support) - confirm.
        # self.assertEqual(K.get_shape(K.diag(x)),
        # K.eval(K.diag(y)).shape)

        # ====== element-wise ops: reported shape == evaluated shape ====== #
        self.assertEqual(K.get_shape(K.square(x)), K.eval(K.square(x)).shape)
        self.assertEqual(K.get_shape(K.abs(x)), K.eval(K.abs(x)).shape)
        self.assertEqual(K.get_shape(K.sqrt(x)), K.eval(K.sqrt(x)).shape)
        self.assertEqual(K.get_shape(K.exp(x)), K.eval(K.exp(x)).shape)
        self.assertEqual(K.get_shape(K.log(x)), K.eval(K.log(x)).shape)
        self.assertEqual(K.get_shape(K.round(x)), K.eval(K.round(x)).shape)
        self.assertEqual(K.get_shape(K.pow(x, 2)), K.eval(K.pow(x, 2)).shape)
        self.assertEqual(K.get_shape(K.clip(x, -1, 1)),
                         K.eval(K.clip(x, -1, 1)).shape)
        self.assertEqual(K.get_shape(K.inv(x)), K.eval(K.inv(x)).shape)
Beispiel #2
0
    def test_basic_ops_value(self):
        """Regression-check the values produced by basic backend ops.

        NumPy's global RNG is seeded with a fixed value so the inputs are
        reproducible.  For each op the evaluated result is summed, scaled and
        rounded, then compared with a hard-coded integer (presumably recorded
        from a reference run - the origin of the constants is not visible
        here).
        """
        np.random.seed(12082518)
        x = K.variable(np.random.randn(8, 8))
        y = K.variable(np.random.randn(8, 8))
        # The builtin ``bool`` is what the ``np.bool`` alias referred to;
        # the alias itself was removed in NumPy 1.24.
        z = K.variable(np.random.randint(0, 2, size=(8, 8)), dtype=bool)
        w = K.variable(np.random.randint(0, 2, size=(8, 8)), dtype=bool)

        # ====== activations and element-wise math ====== #
        self.assertEqual(round(np.sum(K.eval(K.relu(x, alpha=0.12))) * 10000),
                         276733)
        self.assertEqual(round(np.sum(K.eval(K.elu(x, alpha=0.12))) * 10000),
                         289202)
        # A sum of floating-point softmax outputs is compared with a
        # tolerance instead of exact equality with 8.0.
        self.assertAlmostEqual(np.sum(K.eval(K.softmax(x))), 8.0, places=5)
        self.assertEqual(round(np.sum(K.eval(K.softplus(x))) * 10000), 554564)
        self.assertEqual(round(np.sum(K.eval(K.softsign(x))) * 100000), 211582)
        self.assertEqual(round(np.sum(K.eval(K.sigmoid(x))) * 10000), 330427)
        self.assertEqual(round(np.sum(K.eval(K.hard_sigmoid(x))) * 10000),
                         330836)
        self.assertEqual(round(np.sum(K.eval(K.tanh(x))) * 100000), 290165)
        self.assertEqual(round(np.sum(K.eval(K.square(x))) * 10000), 744492)
        # NOTE(review): x holds negative values, yet a finite sum is
        # expected; presumably K.sqrt clips its input - confirm.
        self.assertEqual(round(np.sum(K.eval(K.sqrt(x))) * 10000), 300212)
        self.assertEqual(round(np.sum(K.eval(K.abs(x))) * 10000), 559979)
        self.assertEqual(np.sum(K.eval(K.sign(x))), 6.0)
        self.assertEqual(round(np.sum(K.eval(K.inv(x))) * 1000), 495838)
        self.assertEqual(round(np.sum(K.eval(K.exp(x))) * 1000), 122062)
        self.assertEqual(round(np.sum(K.eval(K.log(K.abs(x)))) * 10000),
                         -344491)
        self.assertEqual(np.sum(K.eval(K.round(x))), 5.0)
        self.assertEqual(round(np.sum(K.eval(K.pow(x, 8))) * 100), 398153)
        self.assertEqual(
            round(np.sum(K.eval(K.clip(x, -0.12, 0.12))) * 1000000), 620529)
        # TODO: pygpu (libgpuarray) still not support diag
        # self.assertEqual(round(np.sum(K.eval(K.diag(x))) * 100000), 325289)
        self.assertEqual(np.sum(K.eval(K.eye(12, 8))), 8.0)

        # ====== comparisons and switch ====== #
        self.assertEqual(np.sum(K.eval(K.eq(z, w))), 38)
        self.assertEqual(np.sum(K.eval(K.neq(z, w))), 26)
        self.assertEqual(np.sum(K.eval(K.gt(x, y))), 33)
        self.assertEqual(np.sum(K.eval(K.ge(x, y))), 33)
        self.assertEqual(np.sum(K.eval(K.lt(x, y))), 31)
        self.assertEqual(np.sum(K.eval(K.le(x, y))), 31)
        self.assertEqual(round(np.sum(K.eval(K.switch(z, x, y))) * 100000),
                         139884)