예제 #1
0
 def _(x):
   """Converts `x` and checks `ensure_federated_value` rejects it with TypeError."""
   # `self` is not a parameter here; presumably captured from an enclosing
   # test method.
   converted = value_impl.to_value(x, None, _context_stack)
   with self.assertRaises(TypeError):
     value_utils.ensure_federated_value(converted)
   return converted
 def test_federated_getattr_comp_fails_value(self):
     """Checks that building a federated getattr from a `Value` raises TypeError."""
     # NOTE(review): `value_impl.to_value` is called below with a single
     # argument, while other call sites pass `(val, type_spec, context_stack)`.
     # If those parameters are required, the expected `TypeError` may come from
     # this call rather than from `construct_federated_getattr_comp` — confirm.
     tuple_type = computation_types.to_type([('x', tf.int32)])
     ref = computation_building_blocks.Reference('x', tuple_type)
     with self.assertRaises(TypeError):
         as_value = value_impl.to_value(ref)
         computation_constructing_utils.construct_federated_getattr_comp(
             as_value, 'x')
예제 #3
0
 def test_slicing_fails_non_namedtuple(self):
   """Slicing a value built from a NumPy array must raise TypeError."""
   ones = np.ones([10, 10, 10], dtype=np.float32)
   tensor_value = value_impl.to_value(ones, None,
                                      context_stack_impl.context_stack)
   with self.assertRaisesRegex(TypeError, 'only supported for named tuples'):
     _ = tensor_value[:1]
예제 #4
0
 def _(x):
   """Converts `x`, checks it via `ensure_federated_value` at CLIENTS, returns it."""
   as_value = value_impl.to_value(x, None, _context_stack)
   value_utils.ensure_federated_value(as_value, placements.CLIENTS)
   return as_value
예제 #5
0
 def test_to_value_with_np_bool(self):
   """A NumPy boolean scalar converts to a value of type `bool`."""
   # `np.bool` was only a deprecated alias of the builtin `bool`; it was
   # removed in NumPy 1.24 (raising AttributeError) and reappeared only in
   # NumPy 2.0. `np.bool_` is NumPy's boolean scalar type in every version.
   value = value_impl.to_value(
       np.bool_(1.0), tf.bool, context_stack_impl.context_stack)
   self.assertIsInstance(value, value_base.Value)
   self.assertEqual(str(value.type_signature), 'bool')
예제 #6
0
 def test_to_value_with_empty_list_of_ints(self):
   """An empty list with an int32 sequence type spec converts to `int32*`."""
   seq_type = computation_types.SequenceType(tf.int32)
   converted = value_impl.to_value([], seq_type,
                                   context_stack_impl.context_stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted.type_signature), 'int32*')
예제 #7
0
def zero_or_one_arg_fn_to_building_block(
    fn,
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
) -> Tuple[building_blocks.ComputationBuildingBlock, computation_types.Type]:
    """Converts a zero- or one-argument `fn` into a computation building block.

  `fn` is invoked once inside a new `FederatedComputationContext` installed on
  `context_stack`. Its result is converted to a value. Any symbol bindings
  recorded in the current context are wrapped around the result as the locals
  of a `building_blocks.Block`, and the whole is returned as a
  `building_blocks.Lambda`.

  Args:
    fn: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `value_base.Value` and returns a result
      convertible to the same. It must not return `None`.
    parameter_name: The name of the parameter, or `None` if there isn't any.
      When given, it is rewritten as `'<context name>_<parameter_name>'`.
    parameter_type: The `tff.Type` of the parameter, or `None` if there's none.
      When `None`, `fn` is called with no arguments.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated context
      that will be used to serialize this function's body (ideally the name of
      the underlying Python function). It might be modified to avoid conflicts.

  Returns:
    A tuple of `(building_blocks.ComputationBuildingBlock,
    computation_types.Type)`. The first element is a `building_blocks.Lambda`
    containing the logic from `fn`. The second element is a
    `computation_types.FunctionType` from `parameter_type` to the (potentially
    annotated) type inferred for the result of `fn`.

  Raises:
    ValueError: if `fn` returns `None`, or if `fn` is incompatible with
      `parameter_type`.
  """
    py_typecheck.check_callable(fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, str)
    # If we are already tracing inside a federated computation, the new context
    # is created with the current one as its parent; otherwise it has none.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, str)
        # Qualify the parameter with the context name (presumably so that
        # parameters of nested contexts cannot collide — confirm).
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        if parameter_type is not None:
            # NOTE(review): assumes `parameter_name` is not `None` whenever
            # `parameter_type` is given; otherwise a `Reference` named `None`
            # is constructed here.
            result = fn(
                value_impl.ValueImpl(
                    building_blocks.Reference(parameter_name, parameter_type),
                    context_stack))
        else:
            result = fn()
        if result is None:
            raise ValueError(
                'The function defined on line {} of file {} has returned a '
                '`NoneType`, but all TFF functions must return some non-`None` '
                'value.'.format(fn.__code__.co_firstlineno,
                                fn.__code__.co_filename))
        # The type is inferred from the raw Python result. It is used both as
        # the conversion target and as the result type of the returned
        # function signature.
        annotated_result_type = type_conversions.infer_type(result)
        result = value_impl.to_value(result, annotated_result_type,
                                     context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        # Symbols bound in the current context (presumably the one installed
        # above) while tracing `fn` become the locals of a `Block` around the
        # result.
        symbols_bound_in_context = context_stack.current.symbol_bindings
        if symbols_bound_in_context:
            result_comp = building_blocks.Block(
                local_symbols=symbols_bound_in_context, result=result_comp)
        annotated_type = computation_types.FunctionType(
            parameter_type, annotated_result_type)
        return building_blocks.Lambda(parameter_name, parameter_type,
                                      result_comp), annotated_type
예제 #8
0
 def test_to_value_with_np_int32(self):
   """A NumPy int32 scalar converts to a value of type `int32`."""
   scalar = np.int32(1)
   result = value_impl.to_value(scalar, tf.int32,
                                context_stack_impl.context_stack)
   self.assertIsInstance(result, value_base.Value)
   self.assertEqual(str(result.type_signature), 'int32')
예제 #9
0
 def _to_value(val):
     """Converts `val` (no type spec) using `context_stack_impl.context_stack`."""
     stack = context_stack_impl.context_stack
     return value_impl.to_value(val, None, stack)
예제 #10
0
def flatten_first_index(apply_func, type_to_add, context_stack):
  """Returns a value `(arg -> APPEND(apply_func(arg[0]), arg[1]))`.

  In the above, `APPEND(a,b)` refers to appending element b to tuple a.

  Constructs a Value of a TFF functional type that:

  1. Takes as argument a 2-element tuple `(x, y)` of TFF type
     `[apply_func.type_signature.parameter, type_to_add]`.

  2. Transforms the 1st element `x` of this 2-tuple by applying `apply_func`,
     producing a result `z` that must be a TFF tuple (e.g., as a result of
     flattening `x`).

  3. Leaves the 2nd element `y` of the argument 2-tuple unchanged.

  4. Returns the result of appending the unchanged `y` at the end of the
     tuple `z` returned by `apply_func`.

  Args:
    apply_func: TFF `Value` of type_signature `FunctionType`, a function taking
      TFF `Value`s to `Value`s of type `NamedTupleType`.
    type_to_add: 2-tuple specifying name and TFF type of arg[1]. Name can be
      `None` or `string`.
    context_stack: The context stack to use, as in `impl.value_impl.to_value`.

  Returns:
    TFF `Value` of `FunctionType`, taking 2-tuples to N-tuples, which calls
      `apply_func` on the first index of its argument, appends the second
      index to the resulting (N-1)-tuple, then returns the N-tuple thus created.

  Raises:
    TypeError: If `apply_func` is not a `Value` whose type is a function
      returning a `NamedTupleType`, or if `type_to_add` is not a tuple
      (presumably raised by `py_typecheck.check_type`).
    ValueError: If `type_to_add` does not have exactly two elements.
  """
  py_typecheck.check_type(apply_func, value_base.Value)
  py_typecheck.check_type(apply_func.type_signature,
                          computation_types.FunctionType)
  py_typecheck.check_type(apply_func.type_signature.result,
                          computation_types.NamedTupleType)
  py_typecheck.check_type(type_to_add, tuple)
  if len(type_to_add) != 2:
    raise ValueError('Please pass a 2-tuple as type_to_add to '
                     'flatten_first_index, with first index name or None '
                     'and second index instance of `computation_types.Type` '
                     'or something convertible to one by '
                     '`computation_types.to_type`.')
  prev_param_type = apply_func.type_signature.parameter
  # Reference to the lambda's 2-tuple parameter `(x, y)`.
  inputs = value_impl.to_value(
      computation_building_blocks.Reference(
          'inputs',
          computation_types.NamedTupleType([prev_param_type, type_to_add])),
      None, context_stack)
  intermediate = apply_func(inputs[0])
  # (name, type) pairs of the output: the elements of `apply_func`'s result
  # followed by `type_to_add`.
  full_type_spec = anonymous_tuple.to_elements(
      apply_func.type_signature.result) + [type_to_add]
  named_values = [
      (full_type_spec[k][0], intermediate[k]) for k in range(len(intermediate))
  ] + [(full_type_spec[-1][0], inputs[1])]
  new_elements = value_impl.to_value(
      anonymous_tuple.AnonymousTuple(named_values),
      type_spec=full_type_spec,
      context_stack=context_stack)
  return value_impl.to_value(
      computation_building_blocks.Lambda(
          'inputs', inputs.type_signature,
          value_impl.ValueImpl.get_comp(new_elements)), None, context_stack)
예제 #11
0
 def test_to_value_sequence_in_tuple_with_type(self):
     expected_type = computation_types.StructWithPythonType(
         [computation_types.SequenceType(tf.int32)], tuple)
     value = value_impl.to_value(([1, 2, 3], ), expected_type,
                                 context_stack_impl.context_stack)
     value.type_signature.check_identical_to(expected_type)
예제 #12
0
 def federated_mean(arg):
     """Averages `arg` by pairing it with a constant weight of 1.

     The weight is a generic constant of `arg`'s own type. The pair is handed
     to `federated_weighted_mean`.
     """
     unit_comp = computation_constructing_utils.create_generic_constant(
         arg.type_signature, 1)
     one = value_impl.ValueImpl(unit_comp, context_stack)
     weighted_pair = value_impl.to_value([arg, one], None, context_stack)
     return federated_weighted_mean(weighted_pair)
def zero_or_one_arg_fn_to_building_block(fn,
                                         parameter_name,
                                         parameter_type,
                                         context_stack,
                                         suggested_name=None):
    """Converts a zero- or one-argument `fn` into a computation building block.

  `fn` is invoked once inside a new `FederatedComputationContext` installed on
  `context_stack`. Its result is converted to a value without a type spec.

  Args:
    fn: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `value_base.Value` and returns a result
      convertible to the same. It must not return `None`.
    parameter_name: The name of the parameter, or `None` if there isn't any.
      When given, it is rewritten as `'<context name>_<parameter_name>'`.
    parameter_type: The TFF type of the parameter, or `None` if there's none.
      It is first normalized with `computation_types.to_type`.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated context
      that will be used to serialize this function's body (ideally the name of
      the underlying Python function). It might be modified to avoid conflicts.
      If not `None`, it must be a string.

  Returns:
    An instance of `building_blocks.ComputationBuildingBlock` that
    contains the logic from `fn`. When `parameter_type` is `None`, this is the
    converted result of `fn()` itself. Otherwise it is a
    `building_blocks.Lambda` over the parameter.

  Raises:
    ValueError: if `fn` returns `None`, or if `fn` is incompatible with
      `parameter_type`.
  """
    py_typecheck.check_callable(fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, six.string_types)
    parameter_type = computation_types.to_type(parameter_type)
    # If we are already tracing inside a federated computation, the new context
    # is created with the current one as its parent; otherwise it has none.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, six.string_types)
        # Qualify the parameter with the context name (presumably so that
        # parameters of nested contexts cannot collide — confirm).
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        if parameter_type is not None:
            # NOTE(review): assumes `parameter_name` is not `None` whenever
            # `parameter_type` is given; otherwise a `Reference` named `None`
            # is constructed here.
            result = fn(
                value_impl.ValueImpl(
                    building_blocks.Reference(parameter_name, parameter_type),
                    context_stack))
        else:
            result = fn()
        if result is None:
            raise ValueError(
                'The function defined on line {} of file {} has returned a '
                '`NoneType`, but all TFF functions must return some non-`None` '
                'value.'.format(fn.__code__.co_firstlineno,
                                fn.__code__.co_filename))
        result = value_impl.to_value(result, None, context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        # A zero-argument function yields its body directly; a one-argument
        # function is wrapped in a lambda over the qualified parameter name.
        if parameter_type is None:
            return result_comp
        else:
            return building_blocks.Lambda(parameter_name, parameter_type,
                                          result_comp)
예제 #14
0
 def test_to_value_for_computations(self):
   """A no-argument TF computation converts to a value of type `( -> int32)`."""
   comp = computations.tf_computation(lambda: tf.constant(10))
   val = value_impl.to_value(comp, None, context_stack_impl.context_stack)
   self.assertIsInstance(val, value_base.Value)
   self.assertEqual(str(val.type_signature), '( -> int32)')
예제 #15
0
    def federated_map(self, fn, arg):
        """Implements `federated_map` as defined in `api/intrinsics.py`.

        `arg` may be a named tuple with two or more elements. In that case all
        elements must be federated values sharing one placement, and the tuple
        is first zipped with `self.federated_zip`. The resulting federated
        value is then mapped as follows:

        - `SERVER` (which must also be `all_equal`): through
          `building_block_factory.create_federated_apply`.
        - `CLIENTS`: through `building_block_factory.create_federated_map`.

        Args:
          fn: As in `api/intrinsics.py`.
          arg: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the server or clients. Would occur after adding support
        # for placement labels in the federated types, and expanding the type
        # specification of the intrinsic this is based on to work with federated
        # values of arbitrary placement.

        arg = value_impl.to_value(arg, None, self._context_stack)
        if isinstance(arg.type_signature, computation_types.NamedTupleType):
            if len(anonymous_tuple.to_elements(arg.type_signature)) >= 2:
                # We've been passed a value which the user expects to be zipped.
                named_type_signatures = anonymous_tuple.to_elements(
                    arg.type_signature)
                _, first_type_signature = named_type_signatures[0]
                # Every element must be federated and share the first
                # element's placement before the tuple can be zipped.
                for _, type_signature in named_type_signatures:
                    py_typecheck.check_type(type_signature,
                                            computation_types.FederatedType)
                    if type_signature.placement is not first_type_signature.placement:
                        raise TypeError(
                            'You cannot apply federated_map on nested values with mixed '
                            'placements (was given a nested value of type {}).'
                            .format(arg.type_signature))
                arg = self.federated_zip(arg)

        py_typecheck.check_type(arg.type_signature,
                                computation_types.FederatedType)

        # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
        # here based on the actual type of the argument.
        fn = value_impl.to_value(fn, None, self._context_stack)

        py_typecheck.check_type(fn, value_base.Value)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        # The mapping function is applied to the member type of `arg`, so its
        # parameter must accept that type.
        if not type_utils.is_assignable_from(fn.type_signature.parameter,
                                             arg.type_signature.member):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(fn.type_signature.parameter,
                        arg.type_signature.member))

        # TODO(b/144384398): Change structure to one that maps the placement type
        # to the building_block function that fits it, in a way that allows the
        # appropriate type checks.
        if arg.type_signature.placement is placements.SERVER:
            if not arg.type_signature.all_equal:
                raise TypeError(
                    'Arguments placed at {} should be equal at all locations.'.
                    format(placements.SERVER))
            fn = value_impl.ValueImpl.get_comp(fn)
            arg = value_impl.ValueImpl.get_comp(arg)
            comp = building_block_factory.create_federated_apply(fn, arg)
        elif arg.type_signature.placement is placements.CLIENTS:
            fn = value_impl.ValueImpl.get_comp(fn)
            arg = value_impl.ValueImpl.get_comp(arg)
            comp = building_block_factory.create_federated_map(fn, arg)
        else:
            raise TypeError(
                'The argument should have placement {} or {}, found {} instead.'
                .format(placements.SERVER, placements.CLIENTS,
                        arg.type_signature.placement))

        return value_impl.ValueImpl(comp, self._context_stack)
예제 #16
0
 def test_to_value_with_string(self):
   """A Python string with a `tf.string` type spec converts to `string`."""
   stack = context_stack_impl.context_stack
   converted = value_impl.to_value('a', tf.string, stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted.type_signature), 'string')
예제 #17
0
 def federated_mean(arg):
     """Averages `arg` by pairing it with a constant weight of 1.

     The weight is a generic constant of `arg`'s own type. The pair is handed
     to `federated_weighted_mean`.
     """
     unit_comp = building_block_factory.create_generic_constant(
         arg.type_signature, 1)
     one = value_impl.ValueImpl(unit_comp, context_stack)
     weighted_pair = value_impl.to_value([arg, one], None, context_stack)
     return federated_weighted_mean(weighted_pair)
예제 #18
0
 def test_to_value_with_np_float64(self):
   """A NumPy float64 scalar converts to a value of type `float64`."""
   scalar = np.float64(1.0)
   result = value_impl.to_value(scalar, tf.float64,
                                context_stack_impl.context_stack)
   self.assertIsInstance(result, value_base.Value)
   self.assertEqual(str(result.type_signature), 'float64')
예제 #19
0
  def federated_mean(self, value, weight):
    """Implements `federated_mean` as defined in `api/intrinsics.py`.

    Requirements checked here:

    - `value` must be a federated value placed at `CLIENTS` whose type is
      average-compatible.
    - `weight`, if given, must be a federated value placed at `CLIENTS` whose
      member is a scalar integer or floating-point tensor.

    The result is a federated value at `SERVER` with `value`'s member type.
    It is built from the `FEDERATED_WEIGHTED_MEAN` intrinsic when a weight is
    given, and from `FEDERATED_MEAN` otherwise.

    Args:
      value: As in `api/intrinsics.py`.
      weight: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    # TODO(b/113112108): Possibly relax the constraints on numeric types, and
    # inject implicit casts where appropriate. For instance, we might want to
    # allow `tf.int32` values as the input, and automatically cast them to
    # `tf.float32` before invoking the average, thus producing a floating-point
    # result.

    # TODO(b/120439632): Possibly allow the weight to be either structured or
    # non-scalar, e.g., for the case of averaging a convolutional layer, when
    # we would want to use a different weight for every filter, and where it
    # might be cumbersome for users to have to manually slice and assemble a
    # variable.

    value = value_impl.to_value(value, None, self._context_stack)
    type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                               'value to be averaged')
    if not type_utils.is_average_compatible(value.type_signature):
      raise TypeError(
          'The value type {} is not compatible with the average operator.'
          .format(str(value.type_signature)))

    if weight is not None:
      weight = value_impl.to_value(weight, None, self._context_stack)
      type_utils.check_federated_value_placement(weight, placements.CLIENTS,
                                                 'weight to use in averaging')
      py_typecheck.check_type(weight.type_signature.member,
                              computation_types.TensorType)
      # Only rank-0 (scalar) weights are supported.
      if weight.type_signature.member.shape.ndims != 0:
        raise TypeError('The weight type {} is not a federated scalar.'.format(
            str(weight.type_signature)))
      if not (weight.type_signature.member.dtype.is_integer or
              weight.type_signature.member.dtype.is_floating):
        raise TypeError('The weight type {} is not a federated integer or '
                        'floating-point tensor.'.format(
                            str(weight.type_signature)))

    # Result lives at SERVER; the third argument (`True`) presumably marks it
    # `all_equal` — confirm against `FederatedType`'s signature.
    result_type = computation_types.FederatedType(value.type_signature.member,
                                                  placements.SERVER, True)

    if weight is not None:
      intrinsic = value_impl.ValueImpl(
          computation_building_blocks.Intrinsic(
              intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri,
              computation_types.FunctionType(
                  [value.type_signature, weight.type_signature], result_type)),
          self._context_stack)
      return intrinsic(value, weight)
    else:
      intrinsic = value_impl.ValueImpl(
          computation_building_blocks.Intrinsic(
              intrinsic_defs.FEDERATED_MEAN.uri,
              computation_types.FunctionType(value.type_signature,
                                             result_type)), self._context_stack)
      return intrinsic(value)
예제 #20
0
 def test_to_value_with_np_ndarray(self):
   """A zero-sized (2, 0) int32 array converts to a value of type `int32[2,0]`."""
   empty_array = np.ndarray(shape=(2, 0), dtype=np.int32)
   type_spec = (tf.int32, [2, 0])
   converted = value_impl.to_value(empty_array, type_spec,
                                   context_stack_impl.context_stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted.type_signature), 'int32[2,0]')
예제 #21
0
  def federated_zip(self, value):
    """Implements `federated_zip` as defined in `api/intrinsics.py`.

    Combines a named tuple of federated values, all placed at the same location
    (`CLIENTS` or `SERVER`), into a single federated value at that placement.

    Args:
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
      ValueError: If `value` is an empty tuple.
    """
    # TODO(b/113112108): Extend this to accept *args.

    # TODO(b/113112108): We use the iterate/unwrap approach below because
    # our type system is not powerful enough to express the concept of
    # "an operation that takes tuples of T of arbitrary length", and therefore
    # the intrinsic federated_zip must only take a fixed number of arguments,
    # here fixed at 2. There are other potential approaches to getting around
    # this problem (e.g. having the operator act on sequences and thereby
    # sidestepping the issue) which we may want to explore.
    value = value_impl.to_value(value, None, self._context_stack)
    py_typecheck.check_type(value, value_base.Value)
    py_typecheck.check_type(value.type_signature,
                            computation_types.NamedTupleType)
    elements_to_zip = anonymous_tuple.to_elements(value.type_signature)
    num_elements = len(elements_to_zip)
    # Reject the empty tuple before indexing into `elements_to_zip`. Checking
    # it after `elements_to_zip[0]` (as before) made this ValueError
    # unreachable, because the indexing raised an IndexError first.
    if num_elements == 0:
      raise ValueError('federated_zip is only supported on nonempty tuples.')
    py_typecheck.check_type(elements_to_zip[0][1],
                            computation_types.FederatedType)
    output_placement = elements_to_zip[0][1].placement
    # The zipping function is applied with the intrinsic that matches the
    # placement of the inputs.
    zip_apply_fn = {
        placements.CLIENTS: self.federated_map,
        placements.SERVER: self.federated_apply
    }
    if output_placement not in zip_apply_fn:
      raise TypeError(
          'federated_zip only supports components with CLIENTS or '
          'SERVER placement, [{}] is unsupported'.format(output_placement))
    if num_elements == 1:
      # A single element is "zipped" by wrapping its member in a 1-tuple that
      # keeps the element's name.
      input_ref = computation_building_blocks.Reference(
          'value_in', elements_to_zip[0][1].member)
      output_tuple = computation_building_blocks.Tuple([(elements_to_zip[0][0],
                                                         input_ref)])
      lam = computation_building_blocks.Lambda('value_in',
                                               input_ref.type_signature,
                                               output_tuple)
      return zip_apply_fn[output_placement](lam, value[0])
    for _, elem in elements_to_zip:
      py_typecheck.check_type(elem, computation_types.FederatedType)
      if elem.placement is not output_placement:
        raise TypeError(
            'The elements of the named tuple to zip must be placed at {}.'
            .format(output_placement))
    named_comps = [(elements_to_zip[k][0],
                    value_impl.ValueImpl.get_comp(value[k]))
                   for k in range(len(value))]
    # Zip the first two elements, then fold in the remaining ones one at a
    # time. Each step nests the accumulated result as the first element of a
    # pair, and `flatten_fn` is extended to append the new element.
    tuple_to_zip = anonymous_tuple.AnonymousTuple(
        [named_comps[0], named_comps[1]])
    zipped = value_utils.zip_two_tuple(
        value_impl.to_value(tuple_to_zip, None, self._context_stack),
        self._context_stack)
    # Start from the identity on the zipped pair's member type.
    inputs = value_impl.to_value(
        computation_building_blocks.Reference('inputs',
                                              zipped.type_signature.member),
        None, self._context_stack)
    flatten_fn = value_impl.to_value(
        computation_building_blocks.Lambda(
            'inputs', zipped.type_signature.member,
            value_impl.ValueImpl.get_comp(inputs)), None, self._context_stack)
    for k in range(2, num_elements):
      zipped = value_utils.zip_two_tuple(
          value_impl.to_value(
              computation_building_blocks.Tuple(
                  [value_impl.ValueImpl.get_comp(zipped), named_comps[k]]),
              None, self._context_stack), self._context_stack)
      last_zipped = (named_comps[k][0], named_comps[k][1].type_signature.member)
      flatten_fn = value_utils.flatten_first_index(flatten_fn, last_zipped,
                                                   self._context_stack)
    return zip_apply_fn[output_placement](flatten_fn, zipped)
예제 #22
0
 def test_to_value_raises_type_error(self):
   """Converting an int with a `tf.bool` type spec raises TypeError."""
   stack = context_stack_impl.context_stack
   with self.assertRaises(TypeError):
     value_impl.to_value(10, tf.bool, stack)
예제 #23
0
  def federated_aggregate(self, value, zero, accumulate, merge, report):
    """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

    Validation performed here:

    - `value` must be a federated value placed at `CLIENTS`.
    - `accumulate`, `merge` and `report` must be functions assignable to the
      reduction types derived from `zero`'s type and `value`'s member type.

    The result is a federated value at `SERVER` whose member type is the
    result type of `report`.

    Args:
      value: As in `api/intrinsics.py`.
      zero: As in `api/intrinsics.py`.
      accumulate: As in `api/intrinsics.py`.
      merge: As in `api/intrinsics.py`.
      report: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                               'value to be aggregated')

    zero = value_impl.to_value(zero, None, self._context_stack)
    py_typecheck.check_type(zero, value_base.Value)

    # TODO(b/113112108): We need a check here that zero does not have federated
    # constituents.

    accumulate = value_impl.to_value(accumulate, None, self._context_stack)
    merge = value_impl.to_value(merge, None, self._context_stack)
    report = value_impl.to_value(report, None, self._context_stack)
    for op in [accumulate, merge, report]:
      py_typecheck.check_type(op, value_base.Value)
      py_typecheck.check_type(op.type_signature, computation_types.FunctionType)

    # Expected signatures: accumulate (zero, member) -> zero, merge
    # (zero, zero) -> zero (presumably the shape `reduction_op` builds — confirm),
    # and report zero -> report's own declared result type.
    accumulate_type_expected = type_constructors.reduction_op(
        zero.type_signature, value.type_signature.member)
    merge_type_expected = type_constructors.reduction_op(
        zero.type_signature, zero.type_signature)
    report_type_expected = computation_types.FunctionType(
        zero.type_signature, report.type_signature.result)
    for op_name, op, type_expected in [
        ('accumulate', accumulate, accumulate_type_expected),
        ('merge', merge, merge_type_expected),
        ('report', report, report_type_expected)
    ]:
      if not type_utils.is_assignable_from(type_expected, op.type_signature):
        raise TypeError('Expected parameter `{}` to be of type {}, '
                        'but received {} instead.'.format(
                            op_name, str(type_expected),
                            str(op.type_signature)))

    result_type = computation_types.FederatedType(report.type_signature.result,
                                                  placements.SERVER, True)
    # The intrinsic is typed with the *expected* operator signatures rather
    # than the operators' actual (assignable) signatures.
    intrinsic = value_impl.ValueImpl(
        computation_building_blocks.Intrinsic(
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
            computation_types.FunctionType([
                value.type_signature, zero.type_signature,
                accumulate_type_expected, merge_type_expected,
                report_type_expected
            ], result_type)), self._context_stack)
    return intrinsic(value, zero, accumulate, merge, report)
예제 #24
0
 def _to_value(cbb):
   """Wraps `cbb` via `value_impl.to_value` with no type spec."""
   stack = context_stack_impl.context_stack
   return value_impl.to_value(cbb, None, stack)
예제 #25
0
    def federated_aggregate(self, value, zero, accumulate, merge, report):
        """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

        `accumulate`, `merge` and `report` are converted with parameter type
        hints derived from `zero` and `value`:

        - `accumulate` with `<zero, member>`.
        - `merge` with `<zero, zero>`.
        - `report` with `zero`'s type.

        The accumulator type used for validation is then taken from
        `accumulate`'s result type rather than from `zero`.

        Args:
          value: As in `api/intrinsics.py`; must be federated at `CLIENTS`.
          zero: As in `api/intrinsics.py`.
          accumulate: As in `api/intrinsics.py`.
          merge: As in `api/intrinsics.py`.
          report: As in `api/intrinsics.py`.

        Returns:
          A `value_impl.ValueImpl` wrapping the federated-aggregate
          computation, after it passes through `self._bind_comp_as_reference`.

        Raises:
          TypeError: If any operator is not a function, if `zero` is not
            assignable to `accumulate`'s first parameter, or if an operator is
            not assignable to its expected signature.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        value = value_utils.ensure_federated_value(value,
                                                   placement_literals.CLIENTS,
                                                   'value to be aggregated')

        zero = value_impl.to_value(zero, None, self._context_stack)
        py_typecheck.check_type(zero, value_base.Value)
        accumulate = value_impl.to_value(
            accumulate,
            None,
            self._context_stack,
            parameter_type_hint=computation_types.StructType(
                [zero.type_signature, value.type_signature.member]))
        merge = value_impl.to_value(
            merge,
            None,
            self._context_stack,
            parameter_type_hint=computation_types.StructType(
                [zero.type_signature, zero.type_signature]))
        report = value_impl.to_value(report,
                                     None,
                                     self._context_stack,
                                     parameter_type_hint=zero.type_signature)
        for op in [accumulate, merge, report]:
            py_typecheck.check_type(op, value_base.Value)
            py_typecheck.check_type(op.type_signature,
                                    computation_types.FunctionType)

        # `zero` only needs to fit `accumulate`'s first parameter; the
        # accumulator type itself comes from `accumulate`'s result below.
        if not accumulate.type_signature.parameter[0].is_assignable_from(
                zero.type_signature):
            raise TypeError('Expected `zero` to be assignable to type {}, '
                            'but was of incompatible type {}.'.format(
                                accumulate.type_signature.parameter[0],
                                zero.type_signature))

        accumulate_type_expected = type_factory.reduction_op(
            accumulate.type_signature.result, value.type_signature.member)
        merge_type_expected = type_factory.reduction_op(
            accumulate.type_signature.result, accumulate.type_signature.result)
        report_type_expected = computation_types.FunctionType(
            merge.type_signature.result, report.type_signature.result)
        for op_name, op, type_expected in [
            ('accumulate', accumulate, accumulate_type_expected),
            ('merge', merge, merge_type_expected),
            ('report', report, report_type_expected)
        ]:
            if not type_expected.is_assignable_from(op.type_signature):
                raise TypeError(
                    'Expected parameter `{}` to be of type {}, but received {} instead.'
                    .format(op_name, type_expected, op.type_signature))

        # Unwrap every operand to its building block before constructing the
        # aggregate computation.
        value = value_impl.ValueImpl.get_comp(value)
        zero = value_impl.ValueImpl.get_comp(zero)
        accumulate = value_impl.ValueImpl.get_comp(accumulate)
        merge = value_impl.ValueImpl.get_comp(merge)
        report = value_impl.ValueImpl.get_comp(report)

        comp = building_block_factory.create_federated_aggregate(
            value, zero, accumulate, merge, report)
        comp = self._bind_comp_as_reference(comp)
        return value_impl.ValueImpl(comp, self._context_stack)
예제 #26
0
 def _(x):
   """Converts `x` and passes it to `value_utils.ensure_federated_value`."""
   as_value = value_impl.to_value(x, None, _context_stack)
   value_utils.ensure_federated_value(as_value)
   return as_value
예제 #27
0
 def test_to_value_for_placement_literals(self):
   """The CLIENTS placement literal converts to a `placement`-typed value."""
   clients_value = value_impl.to_value(placements.CLIENTS, None,
                                       context_stack_impl.context_stack)
   self.assertIsInstance(clients_value, value_base.Value)
   self.assertEqual(str(clients_value.type_signature), 'placement')
   self.assertEqual(str(clients_value), 'CLIENTS')
 def ingest(self, val, type_spec):
     """Converts `val` to a value of `type_spec`, type-checks it, returns it."""
     converted = value_impl.to_value(val, type_spec, self._context_stack)
     type_utils.check_type(converted, type_spec)
     return converted
예제 #29
0
    def sequence_reduce(self, value, zero, op):
        """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

        `value` may be either a sequence or a federated value whose member is
        a sequence.

        - For a plain sequence, the `SEQUENCE_REDUCE` intrinsic is called
          directly.
        - For a federated sequence, a lambda that reduces one member sequence
          is mapped pointwise: with `self.federated_apply` at `SERVER`, and
          with `self.federated_map` at `CLIENTS`.

        Args:
          value: As in `api/intrinsics.py`.
          zero: As in `api/intrinsics.py`.
          op: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`; also for a federated `value`
            placed somewhere other than `SERVER` or `CLIENTS`.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        zero = value_impl.to_value(zero, None, self._context_stack)
        op = value_impl.to_value(op, None, self._context_stack)
        # Determine the element type, unwrapping one federated layer if needed.
        if isinstance(value.type_signature, computation_types.SequenceType):
            element_type = value.type_signature.element
        else:
            py_typecheck.check_type(value.type_signature,
                                    computation_types.FederatedType)
            py_typecheck.check_type(value.type_signature.member,
                                    computation_types.SequenceType)
            element_type = value.type_signature.member.element
        op_type_expected = type_constructors.reduction_op(
            zero.type_signature, element_type)
        if not type_utils.is_assignable_from(op_type_expected,
                                             op.type_signature):
            raise TypeError('Expected an operator of type {}, got {}.'.format(
                str(op_type_expected), str(op.type_signature)))
        # Intrinsic signature: (<element>*, zero, op) -> zero's type.
        sequence_reduce_building_block = computation_building_blocks.Intrinsic(
            intrinsic_defs.SEQUENCE_REDUCE.uri,
            computation_types.FunctionType([
                computation_types.SequenceType(element_type),
                zero.type_signature, op.type_signature
            ], zero.type_signature))
        if isinstance(value.type_signature, computation_types.SequenceType):
            sequence_reduce_intrinsic = value_impl.ValueImpl(
                sequence_reduce_building_block, self._context_stack)
            return sequence_reduce_intrinsic(value, zero, op)
        else:
            # Lambda that reduces a single member sequence; `zero` and `op`
            # are inlined as building blocks into its body.
            federated_mapping_fn_building_block = computation_building_blocks.Lambda(
                'arg', computation_types.SequenceType(element_type),
                computation_building_blocks.Call(
                    sequence_reduce_building_block,
                    computation_building_blocks.Tuple([
                        computation_building_blocks.Reference(
                            'arg',
                            computation_types.SequenceType(element_type)),
                        value_impl.ValueImpl.get_comp(zero),
                        value_impl.ValueImpl.get_comp(op)
                    ])))
            federated_mapping_fn = value_impl.ValueImpl(
                federated_mapping_fn_building_block, self._context_stack)
            if value.type_signature.placement is placements.SERVER:
                return self.federated_apply(federated_mapping_fn, value)
            elif value.type_signature.placement is placements.CLIENTS:
                return self.federated_map(federated_mapping_fn, value)
            else:
                raise TypeError('Unsupported placement {}.'.format(
                    str(value.type_signature.placement)))