예제 #1
0
    def federated_map(self, fn, arg):
        """Implements `federated_map` as defined in `api/intrinsics.py`.

        Args:
          fn: The mapping function, or anything `value_impl.to_value` can
            convert. The member type of `arg` is passed to that conversion as
            `parameter_type_hint`.
          arg: A federated value (or something convertible to one). It must be
            placed at the server or at the clients.

        Returns:
          A `value_impl.ValueImpl` wrapping a building block created by
          `building_block_factory.create_federated_apply` (server placement) or
          `building_block_factory.create_federated_map` (clients placement).

        Raises:
          TypeError: If the member type of `arg` is not assignable to the
            parameter type of `fn`, if a server-placed `arg` does not have its
            `all_equal` bit set, or if `arg` is placed neither at the server nor
            at the clients.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the server or clients. Would occur after adding support
        # for placement labels in the federated types, and expanding the type
        # specification of the intrinsic this is based on to work with federated
        # values of arbitrary placement.
        arg = value_utils.ensure_federated_value(
            value_impl.to_value(arg, None, self._context_stack),
            label='value to be mapped')
        member_type = arg.type_signature.member

        fn = value_impl.to_value(
            fn, None, self._context_stack, parameter_type_hint=member_type)
        py_typecheck.check_type(fn, value_base.Value)
        fn_type = fn.type_signature
        py_typecheck.check_type(fn_type, computation_types.FunctionType)
        if not type_utils.is_assignable_from(fn_type.parameter, member_type):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(fn_type.parameter, member_type))

        # TODO(b/144384398): Change structure to one that maps the placement type
        # to the building_block function that fits it, in a way that allows the
        # appropriate type checks.
        placement = arg.type_signature.placement
        if placement is placements.SERVER:
            # Server-placed values are mapped with `federated_apply`, which is
            # only accepted here when the value is all-equal.
            if not arg.type_signature.all_equal:
                raise TypeError(
                    'Arguments placed at {} should be equal at all locations.'
                    .format(placements.SERVER))
            build_comp = building_block_factory.create_federated_apply
        elif placement is placements.CLIENTS:
            build_comp = building_block_factory.create_federated_map
        else:
            raise TypeError(
                'The argument should be placed at {} or {}, placed at {} instead.'
                .format(placements.SERVER, placements.CLIENTS, placement))

        comp = build_comp(value_impl.ValueImpl.get_comp(fn),
                          value_impl.ValueImpl.get_comp(arg))
        return value_impl.ValueImpl(comp, self._context_stack)
예제 #2
0
    def federated_aggregate(self, value, zero, accumulate, merge, report):
        """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

        Args:
          value: A value (or something convertible to one) that must be placed
            at the clients.
          zero: The initial accumulator. It must be assignable to the first
            parameter type of `accumulate`.
          accumulate: Function whose type must be assignable to
            `type_factory.reduction_op(<accumulate result>, <value member>)`.
          merge: Function whose type must be assignable to
            `type_factory.reduction_op(<accumulate result>, <accumulate result>)`.
          report: Function whose type must be assignable to a function from the
            result type of `merge` to its own result type.

        Returns:
          A `value_impl.ValueImpl` wrapping the building block produced by
          `building_block_factory.create_federated_aggregate`.

        Raises:
          TypeError: If `zero` is not assignable to the first parameter of
            `accumulate`, or if any operator has an unexpected type.
        """
        value = value_utils.ensure_federated_value(
            value_impl.to_value(value, None, self._context_stack),
            placements.CLIENTS, 'value to be aggregated')

        zero = value_impl.to_value(zero, None, self._context_stack)
        py_typecheck.check_type(zero, value_base.Value)

        # Convert all three operators first, then validate them, matching the
        # convert-then-check order of the original implementation.
        accumulate, merge, report = [
            value_impl.to_value(fn, None, self._context_stack)
            for fn in (accumulate, merge, report)
        ]
        for fn in (accumulate, merge, report):
            py_typecheck.check_type(fn, value_base.Value)
            py_typecheck.check_type(fn.type_signature,
                                    computation_types.FunctionType)

        zero_target_type = accumulate.type_signature.parameter[0]
        if not type_utils.is_assignable_from(zero_target_type,
                                             zero.type_signature):
            raise TypeError('Expected `zero` to be assignable to type {}, '
                            'but was of incompatible type {}.'.format(
                                zero_target_type, zero.type_signature))

        accumulator_type = accumulate.type_signature.result
        expectations = (
            ('accumulate', accumulate,
             type_factory.reduction_op(accumulator_type,
                                       value.type_signature.member)),
            ('merge', merge,
             type_factory.reduction_op(accumulator_type, accumulator_type)),
            ('report', report,
             computation_types.FunctionType(merge.type_signature.result,
                                            report.type_signature.result)),
        )
        for op_name, fn, expected_type in expectations:
            if not type_utils.is_assignable_from(expected_type,
                                                 fn.type_signature):
                raise TypeError(
                    'Expected parameter `{}` to be of type {}, but received {} instead.'
                    .format(op_name, expected_type, fn.type_signature))

        comps = [
            value_impl.ValueImpl.get_comp(v)
            for v in (value, zero, accumulate, merge, report)
        ]
        comp = building_block_factory.create_federated_aggregate(*comps)
        return value_impl.ValueImpl(comp, self._context_stack)
예제 #3
0
    def federated_broadcast(self, value):
        """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

        Args:
          value: A value (or something convertible to one) that must be placed
            at the server and have its `all_equal` bit set.

        Returns:
          A `value_impl.ValueImpl` wrapping the building block produced by
          `building_block_factory.create_federated_broadcast`, after it has been
          passed through `self._bind_comp_as_reference`.

        Raises:
          TypeError: If `value` does not have its `all_equal` bit set.
        """
        server_value = value_utils.ensure_federated_value(
            value_impl.to_value(value, None, self._context_stack),
            placement_literals.SERVER, 'value to be broadcasted')

        if server_value.type_signature.all_equal:
            broadcast = building_block_factory.create_federated_broadcast(
                value_impl.ValueImpl.get_comp(server_value))
            return value_impl.ValueImpl(
                self._bind_comp_as_reference(broadcast), self._context_stack)

        raise TypeError(
            'The broadcasted value should be equal at all locations.')
예제 #4
0
    def federated_reduce(self, value, zero, op):
        """Implements `federated_reduce` as defined in `api/intrinsics.py`.

        Args:
          value: As in `api/intrinsics.py`. Must be placed at the clients.
          zero: As in `api/intrinsics.py`.
          op: As in `api/intrinsics.py`. Its type must be assignable to
            `type_factory.reduction_op(<type of zero>, <member type of value>)`.

        Returns:
          As in `api/intrinsics.py`. The computation is built with
          `building_block_factory.create_federated_reduce`.

        Raises:
          TypeError: As in `api/intrinsics.py`; in particular when `op` does not
            have the expected reduction-operator type.
        """
        # TODO(b/113112108): Since in most cases, it can be assumed that CLIENTS is
        # a non-empty collective (or else, the computation fails), specifying zero
        # at this level of the API should probably be optional. TBD.
        clients_value = value_utils.ensure_federated_value(
            value_impl.to_value(value, None, self._context_stack),
            placements.CLIENTS, 'value to be reduced')

        zero_value = value_impl.to_value(zero, None, self._context_stack)
        py_typecheck.check_type(zero_value, value_base.Value)

        # TODO(b/113112108): We need a check here that zero does not have federated
        # constituents.

        op_value = value_impl.to_value(op, None, self._context_stack)
        py_typecheck.check_type(op_value, value_base.Value)
        py_typecheck.check_type(op_value.type_signature,
                                computation_types.FunctionType)
        expected_op_type = type_factory.reduction_op(
            zero_value.type_signature, clients_value.type_signature.member)
        if not type_utils.is_assignable_from(expected_op_type,
                                             op_value.type_signature):
            raise TypeError('Expected an operator of type {}, got {}.'.format(
                expected_op_type, op_value.type_signature))

        reduced = building_block_factory.create_federated_reduce(
            value_impl.ValueImpl.get_comp(clients_value),
            value_impl.ValueImpl.get_comp(zero_value),
            value_impl.ValueImpl.get_comp(op_value))
        return value_impl.ValueImpl(reduced, self._context_stack)
예제 #5
0
    def federated_map_all_equal(self, fn, arg):
        """Implements `federated_map` as defined in `api/intrinsics.py`.

        Implements `federated_map` as defined in `api/intrinsics.py` with an
        argument with the `all_equal` bit set.

        Args:
          fn: As in `api/intrinsics.py`. After conversion it must be a function
            whose parameter type accepts the member type of `arg`.
          arg: As in `api/intrinsics.py`, with the `all_equal` bit set. It must
            be placed at the clients.

        Returns:
          As in `api/intrinsics.py`. The computation is built with
          `building_block_factory.create_federated_map_all_equal`.

        Raises:
          TypeError: As in `api/intrinsics.py`; in particular when the member
            type of `arg` is not assignable to the parameter type of `fn`.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the clients after adding support for placement labels
        # in the federated types, and expanding the type specification of the
        # intrinsic this is based on to work with federated values of arbitrary
        # placement.
        arg = value_impl.to_value(arg, None, self._context_stack)
        arg = value_utils.ensure_federated_value(arg, placements.CLIENTS,
                                                 'value to be mapped')
        # NOTE(review): the `all_equal` bit of `arg` is not validated here;
        # presumably `create_federated_map_all_equal` enforces it — confirm.

        # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
        # here based on the actual type of the argument.
        fn = value_impl.to_value(fn, None, self._context_stack)

        py_typecheck.check_type(fn, value_base.Value)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        if not type_utils.is_assignable_from(fn.type_signature.parameter,
                                             arg.type_signature.member):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(fn.type_signature.parameter,
                        arg.type_signature.member))

        fn = value_impl.ValueImpl.get_comp(fn)
        arg = value_impl.ValueImpl.get_comp(arg)
        comp = building_block_factory.create_federated_map_all_equal(fn, arg)
        return value_impl.ValueImpl(comp, self._context_stack)
예제 #6
0
  def federated_collect(self, value):
    """Implements `federated_collect` as defined in `api/intrinsics.py`.

    Args:
      value: As in `api/intrinsics.py`. Must be placed at the clients.

    Returns:
      As in `api/intrinsics.py`. The computation is built with
      `building_block_factory.create_federated_collect`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    clients_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None, self._context_stack),
        placements.CLIENTS, 'value to be collected')
    collected = building_block_factory.create_federated_collect(
        value_impl.ValueImpl.get_comp(clients_value))
    return value_impl.ValueImpl(collected, self._context_stack)
예제 #7
0
 def federated_secure_sum(self, value, bitwidth):
     """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

     Args:
       value: A value (or something convertible to one) that must be placed at
         the clients; its type is passed to
         `type_utils.check_is_structure_of_integers`.
       bitwidth: A value whose type must be equivalent to the member type of
         `value`.

     Returns:
       A `value_impl.ValueImpl` wrapping the building block produced by
       `building_block_factory.create_federated_secure_sum`.

     Raises:
       TypeError: If the member type of `value` and the type of `bitwidth` are
         not equivalent types.
     """
     summands = value_utils.ensure_federated_value(
         value_impl.to_value(value, None, self._context_stack),
         placements.CLIENTS, 'value to be summed')
     type_utils.check_is_structure_of_integers(summands.type_signature)
     bitwidth = value_impl.to_value(bitwidth, None, self._context_stack)
     member_type = summands.type_signature.member
     bitwidth_type = bitwidth.type_signature

     if type_utils.are_equivalent_types(member_type, bitwidth_type):
         comp = building_block_factory.create_federated_secure_sum(
             value_impl.ValueImpl.get_comp(summands),
             value_impl.ValueImpl.get_comp(bitwidth))
         return value_impl.ValueImpl(comp, self._context_stack)

     raise TypeError(
         'Expected `federated_secure_sum` parameters `value` and `bitwidth` '
         'to have the same structure. Found `value` of `{}` and `bitwidth` of `{}`'
         .format(member_type, bitwidth_type))
예제 #8
0
 def federated_secure_sum(self, value, bitwidth):
   """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

   Args:
     value: A value (or something convertible to one) that must be placed at
       the clients; its type is passed to
       `type_utils.check_is_structure_of_integers`.
     bitwidth: A value whose type must be accepted by
       `type_utils.is_valid_bitwidth_type_for_value_type` for the member type
       of `value`.

   Returns:
     A `value_impl.ValueImpl` wrapping the building block produced by
     `building_block_factory.create_federated_secure_sum`.

   Raises:
     TypeError: If `bitwidth` is not a valid bitwidth type for the member type
       of `value`.
   """
   summands = value_utils.ensure_federated_value(
       value_impl.to_value(value, None, self._context_stack),
       placements.CLIENTS, 'value to be summed')
   type_utils.check_is_structure_of_integers(summands.type_signature)
   bitwidth = value_impl.to_value(bitwidth, None, self._context_stack)
   member_type = summands.type_signature.member
   bitwidth_type = bitwidth.type_signature

   if type_utils.is_valid_bitwidth_type_for_value_type(bitwidth_type,
                                                       member_type):
     comp = building_block_factory.create_federated_secure_sum(
         value_impl.ValueImpl.get_comp(summands),
         value_impl.ValueImpl.get_comp(bitwidth))
     return value_impl.ValueImpl(comp, self._context_stack)

   raise TypeError(
       'Expected `federated_secure_sum` parameter `bitwidth` to match '
       'the structure of `value`, with one integer bitwidth per tensor in '
       '`value`. Found `value` of `{}` and `bitwidth` of `{}`.'.format(
           member_type, bitwidth_type))
예제 #9
0
  def federated_reduce(self, value, zero, op):
    """Implements `federated_reduce` as defined in `api/intrinsics.py`.

    Args:
      value: A value (or something convertible to one) that must be placed at
        the clients.
      zero: The initial value of the reduction. Its type must not contain any
        federated types and must be assignable to the result type of `op`.
      op: The reduction operator. It is converted with a parameter type hint of
        `<type of zero, member type of value>`. Its type must be assignable to
        `type_factory.reduction_op(<result type of op>, <member type of value>)`.

    Returns:
      A `value_impl.ValueImpl` wrapping the building block produced by
      `building_block_factory.create_federated_reduce`, after it has been
      passed through `self._bind_comp_as_reference`.

    Raises:
      TypeError: If `zero` contains a federated type, if `zero` is not
        assignable to the result type of `op`, or if `op` does not have the
        expected reduction-operator type.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    value = value_utils.ensure_federated_value(value,
                                               placement_literals.CLIENTS,
                                               'value to be reduced')

    zero = value_impl.to_value(zero, None, self._context_stack)
    if type_analysis.contains_federated_types(zero.type_signature):
      raise TypeError('`zero` may not contain a federated type, found type:\n' +
                      str(zero.type_signature))

    # The hint lets `op` be converted against the pair of types it will be
    # applied to during the reduction.
    op = value_impl.to_value(
        op,
        None,
        self._context_stack,
        parameter_type_hint=computation_types.StructType(
            [zero.type_signature, value.type_signature.member]))
    op.type_signature.check_function()
    if not op.type_signature.result.is_assignable_from(zero.type_signature):
      # Build ONE message string. Passing the header and the details as two
      # separate arguments would make `str(error)` render a tuple repr, with
      # the newline shown as a literal `\n`.
      raise TypeError(
          '`zero` must be assignable to the result type from `op`:\n{}'.format(
              computation_types.type_mismatch_error_message(
                  zero.type_signature, op.type_signature.result,
                  computation_types.TypeRelation.ASSIGNABLE)))
    op_type_expected = type_factory.reduction_op(op.type_signature.result,
                                                 value.type_signature.member)
    if not op_type_expected.is_assignable_from(op.type_signature):
      raise TypeError('Expected an operator of type {}, got {}.'.format(
          op_type_expected, op.type_signature))

    value = value_impl.ValueImpl.get_comp(value)
    zero = value_impl.ValueImpl.get_comp(zero)
    op = value_impl.ValueImpl.get_comp(op)
    comp = building_block_factory.create_federated_reduce(value, zero, op)
    comp = self._bind_comp_as_reference(comp)
    return value_impl.ValueImpl(comp, self._context_stack)
예제 #10
0
  def federated_sum(self, value):
    """Implements `federated_sum` as defined in `api/intrinsics.py`.

    Args:
      value: As in `api/intrinsics.py`. Must be placed at the clients and have
        a type accepted by `type_utils.is_sum_compatible`.

    Returns:
      As in `api/intrinsics.py`. The computation is built with
      `building_block_factory.create_federated_sum`.

    Raises:
      TypeError: As in `api/intrinsics.py`; in particular when the value type
        is not compatible with the sum operator.
    """
    clients_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None, self._context_stack),
        placements.CLIENTS, 'value to be summed')
    value_type = clients_value.type_signature

    if type_utils.is_sum_compatible(value_type):
      summed = building_block_factory.create_federated_sum(
          value_impl.ValueImpl.get_comp(clients_value))
      return value_impl.ValueImpl(summed, self._context_stack)

    raise TypeError(
        'The value type {} is not compatible with the sum operator.'.format(
            value_type))
예제 #11
0
 def _(x):
   # Convert `x` to a TFF value and assert that `ensure_federated_value`
   # (called without a placement) rejects it with `TypeError`. `self` is a
   # free variable here, presumably the enclosing test case — not visible in
   # this snippet. The converted value is returned.
   converted = value_impl.to_value(x, None, _context_stack)
   with self.assertRaises(TypeError):
     value_utils.ensure_federated_value(converted)
   return converted
예제 #12
0
 def _(x):
   # Convert `x` to a TFF value and run `ensure_federated_value` on it with no
   # placement argument. The check's return value is discarded; the converted
   # value itself is returned.
   converted = value_impl.to_value(x, None, _context_stack)
   value_utils.ensure_federated_value(converted)
   return converted
예제 #13
0
 def _(x):
   # Convert `x` to a TFF value and run `ensure_federated_value` on it with the
   # CLIENTS placement. The check's return value is discarded; the converted
   # value itself is returned.
   converted = value_impl.to_value(x, None, _context_stack)
   value_utils.ensure_federated_value(converted, placement_literals.CLIENTS)
   return converted