def test_returns_string_for_named_tuple_type_one_element(self):
   """Checks compact and formatted rendering of a one-element unnamed tuple."""
   type_spec = computation_types.NamedTupleType((tf.int32,))
   compact_string = computation_types.compact_representation(type_spec)
   self.assertEqual(compact_string, '<int32>')
   formatted_string = computation_types.formatted_representation(type_spec)
   # Keep the expected literal laid out like the rendered multi-line output.
   # pyformat: disable
   self.assertEqual(
       formatted_string,
       '<\n'
       '  int32\n'
       '>'
   )
   # pyformat: enable
 def test_returns_string_for_named_tuple_type_unnamed(self):
   """Checks compact and formatted rendering of a two-element unnamed tuple."""
   type_spec = computation_types.NamedTupleType((tf.int32, tf.float32))
   compact_string = computation_types.compact_representation(type_spec)
   # Same tuple rendering the function-type tests expect inside '(... -> ...)'.
   self.assertEqual(compact_string, '<int32,float32>')
   formatted_string = computation_types.formatted_representation(type_spec)
   # Keep the expected literal laid out like the rendered multi-line output.
   # pyformat: disable
   self.assertEqual(
       formatted_string,
       '<\n'
       '  int32,\n'
       '  float32\n'
       '>'
   )
   # pyformat: enable
 def test_returns_string_for_named_tuple_type_named(self):
   """Checks that named tuple elements render as 'name=type' in both forms."""
   type_spec = computation_types.NamedTupleType(
       (('a', tf.int32), ('b', tf.float32)))
   compact_string = computation_types.compact_representation(type_spec)
   self.assertEqual(compact_string, '<a=int32,b=float32>')
   formatted_string = computation_types.formatted_representation(type_spec)
   # Keep the expected literal laid out like the rendered multi-line output.
   # pyformat: disable
   self.assertEqual(
       formatted_string,
       '<\n'
       '  a=int32,\n'
       '  b=float32\n'
       '>'
   )
   # pyformat: enable
 def test_returns_string_for_function_type_with_named_tuple_type_result(self):
   """Checks rendering of a function type whose result is an unnamed tuple."""
   result = computation_types.NamedTupleType((tf.int32, tf.float32))
   type_spec = computation_types.FunctionType(tf.bool, result)
   compact_string = computation_types.compact_representation(type_spec)
   self.assertEqual(compact_string, '(bool -> <int32,float32>)')
   formatted_string = computation_types.formatted_representation(type_spec)
   # Keep the expected literal laid out like the rendered multi-line output.
   # pyformat: disable
   self.assertEqual(
       formatted_string,
       '(bool -> <\n'
       '  int32,\n'
       '  float32\n'
       '>)'
   )
   # pyformat: enable
 def test_returns_string_for_function_type_with_named_tuple_type_parameter(
     self):
   """Checks rendering of a function type whose parameter is an unnamed tuple."""
   parameter = computation_types.NamedTupleType((tf.int32, tf.float32))
   type_spec = computation_types.FunctionType(parameter, tf.bool)
   compact_string = computation_types.compact_representation(type_spec)
   self.assertEqual(compact_string, '(<int32,float32> -> bool)')
   formatted_string = computation_types.formatted_representation(type_spec)
   # Keep the expected literal laid out like the rendered multi-line output.
   # pyformat: disable
   self.assertEqual(
       formatted_string,
       '(<\n'
       '  int32,\n'
       '  float32\n'
       '> -> bool)'
   )
   # pyformat: enable
 def test_returns_string_for_named_tuple_type_nested(self):
   """Checks rendering of three levels of nested unnamed tuples."""
   type_spec_1 = computation_types.NamedTupleType((tf.int32, tf.float32))
   type_spec_2 = computation_types.NamedTupleType((type_spec_1, tf.bool))
   type_spec_3 = computation_types.NamedTupleType((type_spec_2, tf.string))
   compact_string = computation_types.compact_representation(type_spec_3)
   self.assertEqual(compact_string, '<<<int32,float32>,bool>,string>')
   formatted_string = computation_types.formatted_representation(type_spec_3)
   # Keep the expected literal laid out like the rendered multi-line output;
   # each nesting level adds two spaces of indentation.
   # pyformat: disable
   self.assertEqual(
       formatted_string,
       '<\n'
       '  <\n'
       '    <\n'
       '      int32,\n'
       '      float32\n'
       '    >,\n'
       '    bool\n'
       '  >,\n'
       '  string\n'
       '>'
   )
   # pyformat: enable
 def test_returns_string_for_tensor_type_float(self):
   """A float32 TensorType renders as 'float32' in compact and formatted form."""
   float_tensor = computation_types.TensorType(tf.float32)
   expected = 'float32'
   self.assertEqual(
       computation_types.compact_representation(float_tensor), expected)
   self.assertEqual(
       computation_types.formatted_representation(float_tensor), expected)
 def test_returns_string_for_sequence_type_int(self):
   """An int32 SequenceType renders as 'int32*' in compact and formatted form."""
   int_sequence = computation_types.SequenceType(tf.int32)
   expected = 'int32*'
   self.assertEqual(
       computation_types.compact_representation(int_sequence), expected)
   self.assertEqual(
       computation_types.formatted_representation(int_sequence), expected)
 def test_returns_string_for_placement_type(self):
   """PlacementType renders as 'placement' in compact and formatted form."""
   placement_type = computation_types.PlacementType()
   expected = 'placement'
   self.assertEqual(
       computation_types.compact_representation(placement_type), expected)
   self.assertEqual(
       computation_types.formatted_representation(placement_type), expected)
# Example no. 10
# 0
 def test_returns_string_for_function_type(self):
   """An int32 -> float32 function renders identically in both forms."""
   fn_type = computation_types.FunctionType(tf.int32, tf.float32)
   expected = '(int32 -> float32)'
   self.assertEqual(
       computation_types.compact_representation(fn_type), expected)
   self.assertEqual(
       computation_types.formatted_representation(fn_type), expected)
# Example no. 11
# 0
 def test_returns_string_for_federated_type_server(self):
   """An int32 placed at SERVER renders as 'int32@SERVER' in both forms."""
   server_type = computation_types.FederatedType(tf.int32, placements.SERVER)
   expected = 'int32@SERVER'
   self.assertEqual(
       computation_types.compact_representation(server_type), expected)
   self.assertEqual(
       computation_types.formatted_representation(server_type), expected)
# Example no. 12
# 0
 def test_returns_string_for_federated_type_clients(self):
   """An int32 placed at CLIENTS renders as '{int32}@CLIENTS' in both forms."""
   clients_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
   expected = '{int32}@CLIENTS'
   self.assertEqual(
       computation_types.compact_representation(clients_type), expected)
   self.assertEqual(
       computation_types.formatted_representation(clients_type), expected)
# Example no. 13
# 0
 def test_returns_string_for_abstract_type(self):
   """AbstractType('T') renders as its label 'T' in both forms."""
   abstract_type = computation_types.AbstractType('T')
   expected = 'T'
   self.assertEqual(
       computation_types.compact_representation(abstract_type), expected)
   self.assertEqual(
       computation_types.formatted_representation(abstract_type), expected)
# Example no. 14
# 0
 def test_raises_type_error(self):
   """Both renderers raise TypeError when given None instead of a type."""
   renderers = (computation_types.compact_representation,
                computation_types.formatted_representation)
   for render in renderers:
     with self.assertRaises(TypeError):
       render(None)