예제 #1
0
 def test_make_dummy_element_NamedTupleType(self):
     """Dummy elements for named and unnamed tuples hold zero-filled arrays.

     Each `None` dimension in the component tensor types is expected to come
     back as a 0-length dimension, with the component dtype preserved.
     """
     tensor1 = computation_types.TensorType(tf.float32,
                                            [None, 10, None, 10, 10])
     tensor2 = computation_types.TensorType(tf.int32, [10, None, 10])
     namedtuple = computation_types.NamedTupleType([('x', tensor1),
                                                    ('y', tensor2)])
     # Element types are passed without names so that this variant actually
     # covers the unnamed-tuple case rather than repeating the named one.
     unnamedtuple = computation_types.NamedTupleType([tensor1, tensor2])
     correct_list = [
         np.zeros([0, 10, 0, 10, 10], np.float32),
         np.zeros([10, 0, 10], np.int32)
     ]
     elem = tensorflow_utils.make_dummy_element_for_type_spec(namedtuple)
     self.assertEqual(len(elem), len(correct_list))
     for actual, expected in zip(elem, correct_list):
         self.assertAllClose(actual, expected)
     unnamed_elem = tensorflow_utils.make_dummy_element_for_type_spec(
         unnamedtuple)
     self.assertEqual(len(unnamed_elem), len(correct_list))
     for actual, expected in zip(unnamed_elem, correct_list):
         self.assertAllClose(actual, expected)
예제 #2
0
 def test_make_dummy_element_TensorType(self):
     """A tensor type with `None` dims yields zeros with those dims set to 0."""
     expected = np.zeros([0, 10, 0, 10, 10], np.float32)
     actual = tensorflow_utils.make_dummy_element_for_type_spec(
         computation_types.TensorType(tf.float32, [None, 10, None, 10, 10]))
     self.assertAllClose(actual, expected)
예제 #3
0
 def test_make_dummy_element_for_type_spec_raises_negative_none_dim_replacement(
         self):
     """A negative replacement value for `None` dims raises ValueError."""
     # Passed positionally as the second argument; judging by the test name
     # this is the `None`-dimension replacement value.
     negative_replacement = -1
     with self.assertRaisesRegex(ValueError, 'nonnegative'):
         tensorflow_utils.make_dummy_element_for_type_spec(
             tf.float32, negative_replacement)
예제 #4
0
 def test_make_dummy_element_for_type_spec_raises_SequenceType(self):
     """A sequence type cannot be turned into a dummy array and raises."""
     sequence_spec = computation_types.SequenceType(tf.float32)
     expected_message = 'Cannot construct array for TFF type'
     with self.assertRaisesRegex(ValueError, expected_message):
         tensorflow_utils.make_dummy_element_for_type_spec(sequence_spec)