def test_returns_string_for_named_tuple_type_one_element(self):
  """Checks both string forms of a one-element unnamed NamedTupleType."""
  spec = computation_types.NamedTupleType((tf.int32,))
  # pyformat: disable
  expected_formatted = (
      '<\n'
      ' int32\n'
      '>'
  )
  # pyformat: enable
  self.assertEqual(
      computation_types.compact_representation(spec), '<int32>')
  self.assertEqual(
      computation_types.formatted_representation(spec), expected_formatted)
def test_returns_string_for_named_tuple_type_unnamed(self):
  """Checks both string forms of an unnamed two-element NamedTupleType."""
  type_spec = computation_types.NamedTupleType((tf.int32, tf.float32))
  # The compact form is asserted alongside the formatted form, as in every
  # sibling test; '<int32,float32>' matches the tuple rendering expected
  # inside function types elsewhere in this suite.
  compact_string = computation_types.compact_representation(type_spec)
  self.assertEqual(compact_string, '<int32,float32>')
  formatted_string = computation_types.formatted_representation(type_spec)
  # pyformat: disable
  self.assertEqual(
      formatted_string,
      '<\n'
      ' int32,\n'
      ' float32\n'
      '>'
  )
  # pyformat: enable
def test_returns_string_for_named_tuple_type_named(self):
  """Checks that named elements render as `name=type` in both forms."""
  elements = (('a', tf.int32), ('b', tf.float32))
  spec = computation_types.NamedTupleType(elements)
  self.assertEqual(
      computation_types.compact_representation(spec), '<a=int32,b=float32>')
  # pyformat: disable
  expected_formatted = (
      '<\n'
      ' a=int32,\n'
      ' b=float32\n'
      '>'
  )
  # pyformat: enable
  self.assertEqual(
      computation_types.formatted_representation(spec), expected_formatted)
def test_returns_string_for_function_type_with_named_tuple_type_result(self):
  """Checks a FunctionType whose result is an unnamed NamedTupleType."""
  tuple_result = computation_types.NamedTupleType((tf.int32, tf.float32))
  fn_spec = computation_types.FunctionType(tf.bool, tuple_result)
  self.assertEqual(
      computation_types.compact_representation(fn_spec),
      '(bool -> <int32,float32>)')
  # The tuple result is expanded across lines in the formatted form.
  # pyformat: disable
  expected_formatted = (
      '(bool -> <\n'
      ' int32,\n'
      ' float32\n'
      '>)'
  )
  # pyformat: enable
  self.assertEqual(
      computation_types.formatted_representation(fn_spec), expected_formatted)
def test_returns_string_for_function_type_with_named_tuple_type_parameter(
    self):
  """Checks a FunctionType whose parameter is an unnamed NamedTupleType."""
  tuple_param = computation_types.NamedTupleType((tf.int32, tf.float32))
  fn_spec = computation_types.FunctionType(tuple_param, tf.bool)
  self.assertEqual(
      computation_types.compact_representation(fn_spec),
      '(<int32,float32> -> bool)')
  # The tuple parameter is expanded across lines in the formatted form.
  # pyformat: disable
  expected_formatted = (
      '(<\n'
      ' int32,\n'
      ' float32\n'
      '> -> bool)'
  )
  # pyformat: enable
  self.assertEqual(
      computation_types.formatted_representation(fn_spec), expected_formatted)
def test_returns_string_for_named_tuple_type_nested(self):
  """Checks three levels of nested unnamed NamedTupleTypes."""
  innermost = computation_types.NamedTupleType((tf.int32, tf.float32))
  middle = computation_types.NamedTupleType((innermost, tf.bool))
  outermost = computation_types.NamedTupleType((middle, tf.string))
  self.assertEqual(
      computation_types.compact_representation(outermost),
      '<<<int32,float32>,bool>,string>')
  # Each nesting level opens its own bracket line in the formatted form.
  # pyformat: disable
  expected_formatted = (
      '<\n'
      ' <\n'
      ' <\n'
      ' int32,\n'
      ' float32\n'
      ' >,\n'
      ' bool\n'
      ' >,\n'
      ' string\n'
      '>'
  )
  # pyformat: enable
  self.assertEqual(
      computation_types.formatted_representation(outermost),
      expected_formatted)
def test_returns_string_for_tensor_type_float(self):
  """A float32 TensorType renders as 'float32' in both forms."""
  spec = computation_types.TensorType(tf.float32)
  # Compact is checked first, then formatted, as separate assertions.
  for render in (computation_types.compact_representation,
                 computation_types.formatted_representation):
    self.assertEqual(render(spec), 'float32')
def test_returns_string_for_sequence_type_int(self):
  """An int32 SequenceType renders with a trailing '*' in both forms."""
  spec = computation_types.SequenceType(tf.int32)
  for render in (computation_types.compact_representation,
                 computation_types.formatted_representation):
    self.assertEqual(render(spec), 'int32*')
def test_returns_string_for_placement_type(self):
  """A PlacementType renders as 'placement' in both forms."""
  spec = computation_types.PlacementType()
  for render in (computation_types.compact_representation,
                 computation_types.formatted_representation):
    self.assertEqual(render(spec), 'placement')
def test_returns_string_for_function_type(self):
  """A scalar-to-scalar FunctionType renders identically in both forms."""
  spec = computation_types.FunctionType(tf.int32, tf.float32)
  for render in (computation_types.compact_representation,
                 computation_types.formatted_representation):
    self.assertEqual(render(spec), '(int32 -> float32)')
def test_returns_string_for_federated_type_server(self):
  """A SERVER-placed FederatedType renders as 'int32@SERVER' in both forms."""
  spec = computation_types.FederatedType(tf.int32, placements.SERVER)
  for render in (computation_types.compact_representation,
                 computation_types.formatted_representation):
    self.assertEqual(render(spec), 'int32@SERVER')
def test_returns_string_for_federated_type_clients(self):
  """A CLIENTS-placed FederatedType renders with braces in both forms."""
  spec = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  for render in (computation_types.compact_representation,
                 computation_types.formatted_representation):
    self.assertEqual(render(spec), '{int32}@CLIENTS')
def test_returns_string_for_abstract_type(self):
  """An AbstractType renders as its bare label in both forms."""
  spec = computation_types.AbstractType('T')
  for render in (computation_types.compact_representation,
                 computation_types.formatted_representation):
    self.assertEqual(render(spec), 'T')
def test_raises_type_error(self):
  """Both representation functions reject None with a TypeError."""
  for render in (computation_types.compact_representation,
                 computation_types.formatted_representation):
    with self.assertRaises(TypeError):
      render(None)