Esempio n. 1
0
    def federated_mean(self, value, weight):
        """Implements `federated_mean` as defined in `api/intrinsics.py`.

        Builds a computation averaging a `placements.CLIENTS`-placed value,
        optionally weighted by a `placements.CLIENTS`-placed scalar weight.

        Args:
          value: As in `api/intrinsics.py`. After conversion via
            `value_impl.to_value`, it must be placed at `placements.CLIENTS`
            and its type must satisfy `type_utils.is_average_compatible`.
          weight: As in `api/intrinsics.py`. Either `None` for an unweighted
            mean, or something that converts to a `placements.CLIENTS`-placed
            federated value whose member is a rank-0 `TensorType` with an
            integer or floating-point dtype.

        Returns:
          As in `api/intrinsics.py`: a `value_impl.ValueImpl` wrapping the
          computation produced by
          `computation_constructing_utils.create_federated_mean`.

        Raises:
          TypeError: As in `api/intrinsics.py`. Raised directly here when the
            value type is not average-compatible, or when the weight member is
            not a scalar or not of integer/floating-point dtype. Placement and
            tensor-type checks are delegated to
            `type_utils.check_federated_value_placement` and
            `py_typecheck.check_type`.
        """
        # TODO(b/113112108): Possibly relax the constraints on numeric types, and
        # inject implicit casts where appropriate. For instance, we might want to
        # allow `tf.int32` values as the input, and automatically cast them to
        # `tf.float32` before invoking the average, thus producing a
        # floating-point result.

        # TODO(b/120439632): Possibly allow the weight to be either structured or
        # non-scalar, e.g., for the case of averaging a convolutional layer, when
        # we would want to use a different weight for every filter, and where it
        # might be cumbersome for users to have to manually slice and assemble a
        # variable.

        stack = self._context_stack

        # Validate the value being averaged.
        value = value_impl.to_value(value, None, stack)
        type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                                   'value to be averaged')
        value_type = value.type_signature
        if not type_utils.is_average_compatible(value_type):
            raise TypeError(
                'The value type {} is not compatible with the average operator.'
                .format(value_type))

        # Validate the optional weight: a federated scalar of int/float dtype.
        if weight is not None:
            weight = value_impl.to_value(weight, None, stack)
            type_utils.check_federated_value_placement(
                weight, placements.CLIENTS, 'weight to use in averaging')
            member = weight.type_signature.member
            py_typecheck.check_type(member, computation_types.TensorType)
            if member.shape.ndims != 0:
                raise TypeError(
                    'The weight type {} is not a federated scalar.'.format(
                        weight.type_signature))
            dtype = member.dtype
            if not dtype.is_integer and not dtype.is_floating:
                raise TypeError(
                    'The weight type {} is not a federated integer or floating-point '
                    'tensor.'.format(weight.type_signature))

        # Unwrap to the underlying computations and assemble the mean.
        value_comp = value_impl.ValueImpl.get_comp(value)
        weight_comp = (None if weight is None else
                       value_impl.ValueImpl.get_comp(weight))
        mean_comp = computation_constructing_utils.create_federated_mean(
            value_comp, weight_comp)
        return value_impl.ValueImpl(mean_comp, stack)
Esempio n. 2
0
 def test_is_average_compatible_false(self, type_spec):
     """Checks that `type_utils.is_average_compatible` rejects `type_spec`."""
     compatible = type_utils.is_average_compatible(type_spec)
     self.assertFalse(compatible)
Esempio n. 3
0
 def test_is_average_compatible_true(self, type_spec):
     """Checks that `type_utils.is_average_compatible` accepts `type_spec`."""
     compatible = type_utils.is_average_compatible(type_spec)
     self.assertTrue(compatible)