def federated_mean(self, value, weight):
  """Implements `federated_mean` as defined in `api/intrinsics.py`.

  Args:
    value: As in `api/intrinsics.py`.
    weight: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Possibly relax the constraints on numeric types, and
  # inject implicit casts where appropriate. For instance, we might want to
  # allow `tf.int32` values as the input, and automatically cast them to
  # `tf.float32` before invoking the average, thus producing a floating-point
  # result.

  # TODO(b/120439632): Possibly allow the weight to be either structured or
  # non-scalar, e.g., for the case of averaging a convolutional layer, when
  # we would want to use a different weight for every filter, and where it
  # might be cumbersome for users to have to manually slice and assemble a
  # variable.

  # The averaged value must live at CLIENTS and have an average-compatible
  # type signature.
  value = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                             'value to be averaged')
  value_type = value.type_signature
  if not type_utils.is_average_compatible(value_type):
    raise TypeError(
        'The value type {} is not compatible with the average operator.'
        .format(value_type))

  # The optional weight must also live at CLIENTS, and its member type must be
  # a rank-0 tensor with an integer or floating-point dtype.
  if weight is not None:
    weight = value_impl.to_value(weight, None, self._context_stack)
    type_utils.check_federated_value_placement(
        weight, placements.CLIENTS, 'weight to use in averaging')
    weight_type = weight.type_signature
    member_type = weight_type.member
    py_typecheck.check_type(member_type, computation_types.TensorType)
    if member_type.shape.ndims != 0:
      raise TypeError(
          'The weight type {} is not a federated scalar.'.format(weight_type))
    member_dtype = member_type.dtype
    if not member_dtype.is_integer and not member_dtype.is_floating:
      raise TypeError(
          'The weight type {} is not a federated integer or floating-point '
          'tensor.'.format(weight_type))

  # Unwrap both operands into their underlying computations (weight stays None
  # when absent) and build the federated mean from them.
  value_comp = value_impl.ValueImpl.get_comp(value)
  weight_comp = (
      value_impl.ValueImpl.get_comp(weight) if weight is not None else None)
  comp = computation_constructing_utils.create_federated_mean(
      value_comp, weight_comp)
  return value_impl.ValueImpl(comp, self._context_stack)
def test_is_average_compatible_false(self, type_spec):
  """Asserts that `type_utils.is_average_compatible` rejects `type_spec`."""
  compatible = type_utils.is_average_compatible(type_spec)
  self.assertFalse(compatible)
def test_is_average_compatible_true(self, type_spec):
  """Asserts that `type_utils.is_average_compatible` accepts `type_spec`."""
  compatible = type_utils.is_average_compatible(type_spec)
  self.assertTrue(compatible)