def test_make_dummy_element_NamedTupleType(self):
  """Dummy elements for named and unnamed tuples match zero-filled arrays.

  Each tuple member is expected to be all zeros, with every `None`
  dimension of its TensorType replaced by 0. The check is run once for a
  tuple with element names and once for a tuple without names.
  """
  tensor1 = computation_types.TensorType(tf.float32, [None, 10, None, 10, 10])
  tensor2 = computation_types.TensorType(tf.int32, [10, None, 10])
  namedtuple = computation_types.NamedTupleType([('x', tensor1),
                                                 ('y', tensor2)])
  # Built from bare types (no element names) so this genuinely exercises the
  # unnamed-tuple path instead of duplicating the named case above.
  unnamedtuple = computation_types.NamedTupleType([tensor1, tensor2])
  correct_list = [
      np.zeros([0, 10, 0, 10, 10], np.float32),
      np.zeros([10, 0, 10], np.int32),
  ]

  def assert_matches_correct_list(elem):
    # Same member count, and each member elementwise close to its expectation.
    self.assertEqual(len(elem), len(correct_list))
    for actual, expected in zip(elem, correct_list):
      self.assertAllClose(actual, expected)

  assert_matches_correct_list(
      tensorflow_utils.make_dummy_element_for_type_spec(namedtuple))
  assert_matches_correct_list(
      tensorflow_utils.make_dummy_element_for_type_spec(unnamedtuple))
def test_make_dummy_element_TensorType(self):
  """A TensorType with unknown dims yields zeros with those dims set to 0."""
  expected = np.zeros([0, 10, 0, 10, 10], np.float32)
  tensor_type = computation_types.TensorType(tf.float32,
                                             [None, 10, None, 10, 10])
  self.assertAllClose(
      tensorflow_utils.make_dummy_element_for_type_spec(tensor_type), expected)
def test_make_dummy_element_for_type_spec_raises_negative_none_dim_replacement(
    self):
  """Passing -1 as the second argument raises a 'nonnegative' ValueError."""
  negative_replacement = -1
  with self.assertRaisesRegex(ValueError, 'nonnegative'):
    tensorflow_utils.make_dummy_element_for_type_spec(tf.float32,
                                                      negative_replacement)
def test_make_dummy_element_for_type_spec_raises_SequenceType(self):
  """A SequenceType is rejected with a ValueError about array construction."""
  expected_message = 'Cannot construct array for TFF type'
  sequence_type = computation_types.SequenceType(tf.float32)
  with self.assertRaisesRegex(ValueError, expected_message):
    tensorflow_utils.make_dummy_element_for_type_spec(sequence_type)