Example #1
0
 def test_getattr_comp_construction(self, placement):
     """Checks `construct_federated_getattr_comp` on a federated named tuple.

     Selecting each named element of a federated `<a=int32,b=bool>` value
     must yield the selection lambdas `(x -> x.a)` and `(x -> x.b)`. A
     non-federated argument must raise `TypeError`, and an element name that
     does not exist must raise `ValueError`.
     """
     federated_value = value_impl.to_value(
         computation_building_blocks.Reference(
             'test',
             computation_types.FederatedType(
                 [('a', tf.int32), ('b', tf.bool)], placement, True)),
         None, context_stack_impl.context_stack)
     get_a_comp = (
         computation_constructing_utils.construct_federated_getattr_comp(
             federated_value, 'a'))
     self.assertEqual(str(get_a_comp), '(x -> x.a)')
     get_b_comp = (
         computation_constructing_utils.construct_federated_getattr_comp(
             federated_value, 'b'))
     self.assertEqual(str(get_b_comp), '(x -> x.b)')
     # The same tuple type without a federated placement must be rejected.
     non_federated_arg = value_impl.to_value(
         computation_building_blocks.Reference(
             'test',
             computation_types.NamedTupleType(
                 [('a', tf.int32), ('b', tf.bool)])),
         None, context_stack_impl.context_stack)
     with self.assertRaises(TypeError):
         computation_constructing_utils.construct_federated_getattr_comp(
             non_federated_arg, 'a')
     # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12;
     # `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(ValueError, 'has no element of name c'):
         computation_constructing_utils.construct_federated_getattr_comp(
             federated_value, 'c')
Example #2
0
 def test_federated_getattr_comp_fails_value(self):
     """Checks that `construct_federated_getattr_comp` raises `TypeError`.

     The argument is the result of `value_impl.to_value` on a reference of
     the non-federated type `<x=int32>`. It is built outside the
     `assertRaises` context, so a `TypeError` raised while constructing it
     cannot make this test pass by accident.
     """
     x = computation_building_blocks.Reference(
         'x', computation_types.to_type([('x', tf.int32)]))
     # Pass `type_spec` and `context_stack` explicitly, as the other calls to
     # `to_value` in this file do. They are presumably required by its
     # signature, in which case the original one-argument call could itself
     # raise the `TypeError` this test expects.
     x_value = value_impl.to_value(x, None, context_stack_impl.context_stack)
     with self.assertRaises(TypeError):
         computation_constructing_utils.construct_federated_getattr_comp(
             x_value, 'x')