示例#1
0
 async def _compute_intrinsic_federated_weighted_mean(self, arg):
   """Computes a weighted mean of client-placed values.

   The products `value * weight` are summed across clients, the weights are
   summed separately, and the two sums are divided at the server.

   Args:
     arg: A `FederatedExecutorValue` validated by
       `type_utils.check_valid_federated_weighted_mean_argument_tuple_type`.
       Element 1 of its representation/type is summed on its own as the total
       weight (element 0 is presumably the values being averaged).

   Returns:
     The value produced by `_compute_intrinsic_federated_apply` for the
     server-side division.
   """
   type_utils.check_valid_federated_weighted_mean_argument_tuple_type(
       arg.type_signature)

   def _function_applied_to(operator_blk, operand):
     # Packs an operator building block together with its operand into the
     # (fn, arg) pair consumed by the federated map/apply intrinsics.
     return FederatedExecutorValue(
         anonymous_tuple.AnonymousTuple([
             (None, operator_blk.proto),
             (None, operand.internal_representation),
         ]),
         computation_types.NamedTupleType(
             [operator_blk.type_signature, operand.type_signature]))

   # Zip the two client-placed components so each client holds one pair.
   zipped_arg = await self._compute_intrinsic_federated_zip_at_clients(arg)
   # TODO(b/134543154): Replace with something that produces a section of
   # plain TensorFlow code instead of constructing a lambda (so that this
   # can be executed directly on top of a plain TensorFlow-based executor).
   multiply_blk = intrinsic_utils.construct_binary_operator_with_upcast(
       zipped_arg.type_signature.member, tf.multiply)
   products = await self._compute_intrinsic_federated_map(
       _function_applied_to(multiply_blk, zipped_arg))
   sum_of_products = await self._compute_intrinsic_federated_sum(products)

   weights = FederatedExecutorValue(arg.internal_representation[1],
                                    arg.type_signature[1])
   total_weight = await self._compute_intrinsic_federated_sum(weights)

   # Bring numerator and denominator together at the server, then divide.
   server_pair = await self.create_tuple(
       anonymous_tuple.AnonymousTuple([(None, sum_of_products),
                                       (None, total_weight)]))
   divide_arg = await self._compute_intrinsic_federated_zip_at_server(
       server_pair)
   divide_blk = intrinsic_utils.construct_binary_operator_with_upcast(
       divide_arg.type_signature.member, tf.divide)
   return await self._compute_intrinsic_federated_apply(
       _function_applied_to(divide_blk, divide_arg))
 def test_construct_generic_raises_federated_type(self):
     """Passing a client-placed `FederatedType` (not a two-tuple) raises."""
     clients_pair_type = computation_types.FederatedType(
         [[tf.int32, tf.int32],
          computation_types.TensorType(tf.float32, [2])],
         placement_literals.CLIENTS)
     expected_message = 'argument that is not a two-tuple'
     with self.assertRaisesRegex(TypeError, expected_message):
         intrinsic_utils.construct_binary_operator_with_upcast(
             clients_pair_type, tf.multiply)
 def test_raises_non_callable_op(self):
     """Both upcast helpers reject an operator that is not callable."""
     bad_type_ref = computation_building_blocks.Reference(
         'x', [tf.float32, tf.float32])
     # `apply_*` operates on a building block, so the reference itself is
     # passed here.
     with self.assertRaisesRegex(TypeError, 'non-callable'):
         intrinsic_utils.apply_binary_operator_with_upcast(
             bad_type_ref, tf.constant(0))
     # `construct_*` takes a type signature (as in the sibling tests), so pass
     # the reference's type. The non-callable operator is then the only
     # problem with the call, and the assertion does not depend on the order
     # in which the helper validates its arguments.
     with self.assertRaisesRegex(TypeError, 'non-callable'):
         intrinsic_utils.construct_binary_operator_with_upcast(
             bad_type_ref.type_signature, tf.constant(0))
 def test_raises_tuple_scalar_multiplied_by_nonscalar(self):
     """An (int32 scalar, float32[2]) pair is rejected by both helpers."""
     scalar_and_vector = computation_building_blocks.Reference(
         'x', [tf.int32, computation_types.TensorType(tf.float32, [2])])
     error_pattern = 'incompatible with upcasted'
     # `apply_*` takes the building block; `construct_*` takes its type.
     for helper, operand in (
         (intrinsic_utils.apply_binary_operator_with_upcast,
          scalar_and_vector),
         (intrinsic_utils.construct_binary_operator_with_upcast,
          scalar_and_vector.type_signature),
     ):
         with self.assertRaisesRegex(TypeError, error_pattern):
             helper(operand, tf.multiply)
 def test_raises_incompatible_tuple_and_tensor(self):
     """A clients-placed ([int32, int32], float32) pair cannot be upcast."""
     federated_ref = computation_building_blocks.Reference(
         'x',
         computation_types.FederatedType([[tf.int32, tf.int32], tf.float32],
                                         placement_literals.CLIENTS))
     error_pattern = 'incompatible with upcasted'
     with self.assertRaisesRegex(TypeError, error_pattern):
         intrinsic_utils.apply_binary_operator_with_upcast(
             federated_ref, tf.multiply)
     # The type-level constructor is given the member type, without the
     # placement.
     member_type = federated_ref.type_signature.member
     with self.assertRaisesRegex(TypeError, error_pattern):
         intrinsic_utils.construct_binary_operator_with_upcast(
             member_type, tf.multiply)
 def test_construct_integer_type_signature(self):
     """Multiplying an int32 pair gives an int32 binary-op function type."""
     pair_ref = computation_building_blocks.Reference(
         'x', [tf.int32, tf.int32])
     product_fn = intrinsic_utils.construct_binary_operator_with_upcast(
         pair_ref.type_signature, tf.multiply)
     int32_type = computation_types.to_type(tf.int32)
     expected_type = type_constructors.binary_op(int32_type)
     self.assertEqual(product_fn.type_signature, expected_type)
示例#7
0
 def federated_sum(x):
     """Sums `x` by reducing it with an upcasting `tf.add` operator.

     Args:
       x: Presumably a federated value; its `type_signature.member` is used
         as the element type for both the zero and the add operator.

     Returns:
       The result of `federated_reduce([x, zero, plus_op])`, where `zero`
       comes from `intrinsic_utils.zero_for` on the member type.
     """
     member_type = x.type_signature.member
     zero = intrinsic_utils.zero_for(member_type, context_stack)
     # The add operator takes a (member, member) pair.
     add_arg_type = computation_types.NamedTupleType(
         [member_type, member_type])
     add_blk = intrinsic_utils.construct_binary_operator_with_upcast(
         add_arg_type, tf.add)
     plus_op = value_impl.ValueImpl(add_blk, context_stack)
     return federated_reduce([x, zero, plus_op])
 def test_construct_divide_op_named_tuple_with_scalar_type_signature(self):
     """Dividing a named float32 pair by a float32 keeps the named-pair type."""
     type_spec = computation_types.to_type([[('a', tf.float32),
                                             ('b', tf.float32)],
                                            tf.float32])
     # Built from `tf.divide`; previously misleadingly named `multiplier`.
     divider = intrinsic_utils.construct_binary_operator_with_upcast(
         type_spec, tf.divide)
     # The result type is that of the first (named-tuple) element.
     expected_function_type = computation_types.FunctionType(
         type_spec, type_spec[0])
     self.assertEqual(divider.type_signature, expected_function_type)
 def test_construct_op_raises_on_none_operator(self):
     """Passing `None` as the operator raises a non-callable `TypeError`."""
     missing_op = None
     expected_message = 'found non-callable'
     with self.assertRaisesRegex(TypeError, expected_message):
         intrinsic_utils.construct_binary_operator_with_upcast(
             tf.int32, missing_op)