Ejemplo n.º 1
0
 def testGain(self):
   """Identity() matches np.eye; Identity(gain=0.9) matches 0.9 * np.eye.

   Checked for both float32 and float64, each evaluation under use_gpu().
   """
   shape = (10, 10)
   unit_eye = np.eye(*shape)
   scaled_eye = unit_eye * 0.9
   for dtype in (dtypes.float32, dtypes.float64):
     # Pair each initializer with the matrix it is expected to produce.
     cases = (
         (init_ops_v2.Identity(), unit_eye),
         (init_ops_v2.Identity(gain=0.9), scaled_eye),
     )
     for initializer, expected in cases:
       with test_util.use_gpu():
         produced = self.evaluate(initializer(shape, dtype=dtype))
         self.assertAllClose(produced, expected)
Ejemplo n.º 2
0
    def testRange(self):
        """Run _range_test on Identity() for a rank-3 and a square shape.

        The rank-3 shape (3, 4, 5) is expected to raise ValueError; the
        square shape (3, 3) is checked against mean 1/3 and max 1.
        """
        rank3_shape = (3, 4, 5)
        with self.assertRaises(ValueError):
            self._range_test(
                init_ops_v2.Identity(),
                shape=rank3_shape,
                target_mean=1. / rank3_shape[0],
                target_max=1.)

        square_shape = (3, 3)
        self._range_test(
            init_ops_v2.Identity(),
            shape=square_shape,
            target_mean=1. / square_shape[0],
            target_max=1.)
Ejemplo n.º 3
0
 def testPartition(self):
   """Calling Identity with partition_shape raises a specific ValueError."""
   initializer = init_ops_v2.Identity()
   expected_message = (
       r"Identity initializer doesn't support partition-related arguments")
   with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
     initializer((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
Ejemplo n.º 4
0
 def create_vars():
   """Append three 2x2 float32 variables to the enclosing `collection`.

   Does nothing when `collection` is already non-empty, so repeated calls
   create the variables only once. The first variable uses an explicit
   literal; the other two are built from an Identity initializer callable
   (the last one also passes dtype and shape explicitly).
   """
   if collection:
     return
   eye_initializer = init_ops_v2.Identity()

   def make_eye():
     # Deferred initial value: the Variable invokes this callable itself.
     return eye_initializer((2, 2), dtypes.float32)

   literal_var = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
   callable_var = variables.Variable(make_eye)
   typed_callable_var = variables.Variable(
       make_eye, dtype=dtypes.float32, shape=(2, 2))
   collection.extend([literal_var, callable_var, typed_callable_var])
Ejemplo n.º 5
0
 def testNonSquare(self):
   """A rectangular (10, 5) request matches np.eye(10, 5)."""
   rows, cols = 10, 5
   initializer = init_ops_v2.Identity()
   with test_util.use_gpu():
     produced = self.evaluate(initializer((rows, cols)))
     self.assertAllClose(produced, np.eye(rows, cols))
Ejemplo n.º 6
0
 def testInvalidShape(self):
   """Rank-3, rank-1 and scalar shapes each raise ValueError."""
   initializer = init_ops_v2.Identity()
   with test_util.use_gpu():
     for bad_shape in ([5, 7, 7], [5], []):
       self.assertRaises(ValueError, initializer, shape=bad_shape)
Ejemplo n.º 7
0
 def testInvalidDataType(self):
   """Requesting an int32 identity matrix raises ValueError."""
   initializer = init_ops_v2.Identity()
   with self.assertRaises(ValueError):
     initializer(shape=[10, 5], dtype=dtypes.int32)