def test_make_whimsy_element_tensor_type_none_replaced_by_1(self):
  """Unknown (`None`) dimensions take the given `none_dim_replacement`."""
  float_type = computation_types.TensorType(tf.float32,
                                            [None, 10, None, 10, 10])
  actual = tensorflow_utils.make_whimsy_element_for_type_spec(
      float_type, none_dim_replacement=1)
  # Both `None` axes should have become size 1; fixed axes are untouched.
  expected = np.zeros([1, 10, 1, 10, 10], np.float32)
  self.assertAllClose(actual, expected)
def test_make_whimsy_element_string_tensor(self):
  """A string tensor yields a numpy array whose entries are empty strings."""
  string_type = computation_types.TensorType(tf.string, [None])
  result = tensorflow_utils.make_whimsy_element_for_type_spec(
      string_type, none_dim_replacement=1)
  self.assertIsInstance(result, np.ndarray)
  # The single `None` axis is replaced by 1, giving exactly one element.
  self.assertAllEqual(result.shape, [1])
  self.assertEqual(result[0], '')
def test_make_whimsy_element_StructType(self):
  """Struct types yield one zero-filled array per element, in order.

  Both a named struct (fields `x`, `y`) and an unnamed struct with the same
  element types are checked. No `none_dim_replacement` is passed, so every
  `None` dimension is expected to become size 0.
  """
  tensor1 = computation_types.TensorType(tf.float32, [None, 10, None, 10, 10])
  tensor2 = computation_types.TensorType(tf.int32, [10, None, 10])
  named_struct = computation_types.StructType([('x', tensor1),
                                               ('y', tensor2)])
  # Deliberately built without field names so that the unnamed-struct code
  # path is exercised (it must not duplicate `named_struct`).
  unnamed_struct = computation_types.StructType([tensor1, tensor2])
  expected = [
      np.zeros([0, 10, 0, 10, 10], np.float32),
      np.zeros([10, 0, 10], np.int32),
  ]

  for label, struct_type in (('named', named_struct),
                             ('unnamed', unnamed_struct)):
    with self.subTest(struct=label):
      elem = tensorflow_utils.make_whimsy_element_for_type_spec(struct_type)
      self.assertEqual(len(elem), len(expected))
      for actual_part, expected_part in zip(elem, expected):
        self.assertAllClose(actual_part, expected_part)
def test_make_whimsy_element_tensor_type_backed_by_tf_dimension(self):
  """A shape given as `tf.compat.v1.Dimension` objects behaves like raw ints."""
  dimensions = [
      tf.compat.v1.Dimension(size) for size in (None, 10, None, 10, 10)
  ]
  dimension_backed_type = computation_types.TensorType(tf.float32, dimensions)
  actual = tensorflow_utils.make_whimsy_element_for_type_spec(
      dimension_backed_type)
  # With no replacement given, the unknown dimensions collapse to size 0.
  self.assertAllClose(actual, np.zeros([0, 10, 0, 10, 10], np.float32))
def test_make_whimsy_element_TensorType(self):
  """A plain tensor type yields zeros, with `None` dimensions sized 0."""
  tensor_type = computation_types.TensorType(tf.float32,
                                             [None, 10, None, 10, 10])
  actual = tensorflow_utils.make_whimsy_element_for_type_spec(tensor_type)
  self.assertAllClose(actual, np.zeros([0, 10, 0, 10, 10], np.float32))
def test_make_whimsy_element_for_type_spec_raises_negative_none_dim_replacement(
    self):
  """A negative replacement for `None` dimensions raises ValueError."""
  negative_replacement = -1
  with self.assertRaisesRegex(ValueError, 'nonnegative'):
    tensorflow_utils.make_whimsy_element_for_type_spec(tf.float32,
                                                       negative_replacement)
def test_make_whimsy_element_for_type_spec_raises_SequenceType(self):
  """Requesting an element for a sequence type is rejected with ValueError."""
  sequence_type = computation_types.SequenceType(tf.float32)
  expected_message = 'Cannot construct array for TFF type'
  with self.assertRaisesRegex(ValueError, expected_message):
    tensorflow_utils.make_whimsy_element_for_type_spec(sequence_type)