示例#1
0
 def test_make_whimsy_element_tensor_type_none_replaced_by_1(self):
   """Each unknown (None) dimension is filled with `none_dim_replacement`."""
   expected = np.zeros([1, 10, 1, 10, 10], np.float32)
   spec = computation_types.TensorType(tf.float32, [None, 10, None, 10, 10])
   actual = tensorflow_utils.make_whimsy_element_for_type_spec(
       spec, none_dim_replacement=1)
   self.assertAllClose(actual, expected)
示例#2
0
 def test_make_whimsy_element_string_tensor(self):
   """A string tensor spec yields a NumPy array holding empty strings."""
   string_spec = computation_types.TensorType(tf.string, [None])
   whimsy = tensorflow_utils.make_whimsy_element_for_type_spec(
       string_spec, none_dim_replacement=1)
   # The single unknown dimension is replaced by 1, so exactly one entry exists.
   self.assertIsInstance(whimsy, np.ndarray)
   self.assertAllEqual(whimsy.shape, [1])
   self.assertEqual(whimsy[0], '')
示例#3
0
 def test_make_whimsy_element_StructType(self):
   """Named and unnamed structs both produce one zero array per field.

   No `none_dim_replacement` is passed, so the expected arrays have size 0
   along every unknown (None) dimension.
   """
   tensor1 = computation_types.TensorType(tf.float32, [None, 10, None, 10, 10])
   tensor2 = computation_types.TensorType(tf.int32, [10, None, 10])
   named_struct = computation_types.StructType([('x', tensor1),
                                                ('y', tensor2)])
   # The unnamed variant must omit field names; building it with the same
   # ('x', ...), ('y', ...) pairs would silently re-test the named case.
   unnamed_struct = computation_types.StructType([tensor1, tensor2])
   correct_list = [
       np.zeros([0, 10, 0, 10, 10], np.float32),
       np.zeros([10, 0, 10], np.int32),
   ]
   for label, struct_type in (('named', named_struct),
                              ('unnamed', unnamed_struct)):
     with self.subTest(struct=label):
       elem = tensorflow_utils.make_whimsy_element_for_type_spec(struct_type)
       self.assertEqual(len(elem), len(correct_list))
       for actual, expected in zip(elem, correct_list):
         self.assertAllClose(actual, expected)
示例#4
0
 def test_make_whimsy_element_tensor_type_backed_by_tf_dimension(self):
   """A shape of `tf.compat.v1.Dimension` objects gives the plain-shape result."""
   dims = [tf.compat.v1.Dimension(d) for d in (None, 10, None, 10, 10)]
   type_spec = computation_types.TensorType(tf.float32, dims)
   elem = tensorflow_utils.make_whimsy_element_for_type_spec(type_spec)
   # No replacement is given, so unknown dimensions are expected to become 0.
   self.assertAllClose(elem, np.zeros([0, 10, 0, 10, 10], np.float32))
示例#5
0
 def test_make_whimsy_element_TensorType(self):
   """Unknown dimensions default to size 0 in the generated tensor."""
   shape = [None, 10, None, 10, 10]
   elem = tensorflow_utils.make_whimsy_element_for_type_spec(
       computation_types.TensorType(tf.float32, shape))
   self.assertAllClose(elem, np.zeros([0, 10, 0, 10, 10], np.float32))
示例#6
0
 def test_make_whimsy_element_for_type_spec_raises_negative_none_dim_replacement(
      self):
   """A negative `none_dim_replacement` is rejected with a ValueError."""
   with self.assertRaisesRegex(ValueError, 'nonnegative'):
     tensorflow_utils.make_whimsy_element_for_type_spec(
         tf.float32, none_dim_replacement=-1)
示例#7
0
 def test_make_whimsy_element_for_type_spec_raises_SequenceType(self):
   """A sequence type cannot be turned into an array and raises ValueError."""
   sequence_spec = computation_types.SequenceType(tf.float32)
   expected_message = 'Cannot construct array for TFF type'
   with self.assertRaisesRegex(ValueError, expected_message):
     tensorflow_utils.make_whimsy_element_for_type_spec(sequence_spec)