Beispiel #1
0
 def test_two_tuple_zip_fails_bad_args(self):
   """Checks that `zip_two_tuple` rejects each kind of malformed argument."""

   def make_tuple_ref(ref_name, member_specs):
     # Turns (dtype, placement) pairs into all-equal federated members of a
     # named tuple and returns a Reference of that tuple type.
     members = [
         computation_types.FederatedType(dtype, placement, True)
         for dtype, placement in member_specs
     ]
     return computation_building_blocks.Reference(
         ref_name, computation_types.NamedTupleType(members))

   # Case 1: the second member is placed at SERVER instead of CLIENTS.
   mixed_ref = make_tuple_ref(
       'test', [(tf.int32, placements.CLIENTS), (tf.bool, placements.SERVER)])
   with self.assertRaisesRegexp(TypeError, 'should be placed at CLIENTS'):
     mixed_value = value_impl.to_value(mixed_ref, None, _context_stack)
     _ = value_utils.zip_two_tuple(mixed_value, _context_stack)

   # Case 2: a bare building block is passed without being wrapped as a Value.
   unwrapped_ref = make_tuple_ref(
       'test', [(tf.int32, placements.CLIENTS), (tf.bool, placements.CLIENTS)])
   with self.assertRaisesRegexp(TypeError, '(Expected).*(Value)'):
     _ = value_utils.zip_two_tuple(unwrapped_ref, _context_stack)

   # Case 3: three members where exactly two are required.
   triple_ref = make_tuple_ref('three_tuple_test',
                               [(tf.int32, placements.CLIENTS)] * 3)
   with self.assertRaisesRegexp(ValueError, 'must be a 2-tuple'):
     triple_value = value_impl.to_value(triple_ref, None, _context_stack)
     _ = value_utils.zip_two_tuple(triple_value, _context_stack)
Beispiel #2
0
 def test_two_tuple_zip_with_client_all_equal_int_and_bool(self):
   """Zipping an (int32, bool) pair at CLIENTS yields one CLIENTS value."""
   int_member = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                                True)
   bool_member = computation_types.FederatedType(tf.bool, placements.CLIENTS,
                                                 True)
   pair_type = computation_types.NamedTupleType([int_member, bool_member])
   pair_ref = computation_building_blocks.Reference('test', pair_type)
   pair_value = value_impl.to_value(pair_ref, None, _context_stack)
   zipped = value_utils.zip_two_tuple(pair_value, _context_stack)
   expected_signature = '{<int32,bool>}@CLIENTS'
   self.assertEqual(str(zipped.type_signature), expected_signature)
  def federated_zip(self, value):
    """Implements `federated_zip` as defined in `api/intrinsics.py`.

    Combines a named tuple of federated values that all share one placement
    (CLIENTS or SERVER) into a single federated value at that placement.

    Args:
      value: As in `api/intrinsics.py`. After conversion with
        `value_impl.to_value`, it must have a `NamedTupleType` type signature
        whose elements are all `FederatedType`s with the same placement.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`. Also raised if an element is not
        federated, if the elements do not share a placement, or if the
        placement is neither CLIENTS nor SERVER.
      ValueError: If `value` is an empty tuple.
    """
    # TODO(b/113112108): Extend this to accept *args.

    # TODO(b/113112108): We use the iterate/unwrap approach below because
    # our type system is not powerful enough to express the concept of
    # "an operation that takes tuples of T of arbitrary length", and therefore
    # the intrinsic federated_zip must only take a fixed number of arguments,
    # here fixed at 2. There are other potential approaches to getting around
    # this problem (e.g. having the operator act on sequences and thereby
    # sidestepping the issue) which we may want to explore.
    value = value_impl.to_value(value, None, self._context_stack)
    py_typecheck.check_type(value, value_base.Value)
    py_typecheck.check_type(value.type_signature,
                            computation_types.NamedTupleType)
    elements_to_zip = anonymous_tuple.to_elements(value.type_signature)
    num_elements = len(elements_to_zip)
    # Check emptiness before indexing into `elements_to_zip`; otherwise an
    # empty tuple would surface as an IndexError instead of this ValueError.
    if num_elements == 0:
      raise ValueError('federated_zip is only supported on nonempty tuples.')
    py_typecheck.check_type(elements_to_zip[0][1],
                            computation_types.FederatedType)
    # The first element's placement determines the placement of the result;
    # the remaining elements are checked against it further below.
    output_placement = elements_to_zip[0][1].placement
    zip_apply_fn = {
        placements.CLIENTS: self.federated_map,
        placements.SERVER: self.federated_apply
    }
    if output_placement not in zip_apply_fn:
      raise TypeError(
          'federated_zip only supports components with CLIENTS or '
          'SERVER placement, [{}] is unsupported'.format(output_placement))
    if num_elements == 1:
      # Nothing to zip: wrap the lone member in a one-element tuple that keeps
      # the element's name, and apply that at the original placement.
      input_ref = computation_building_blocks.Reference(
          'value_in', elements_to_zip[0][1].member)
      output_tuple = computation_building_blocks.Tuple([(elements_to_zip[0][0],
                                                         input_ref)])
      lam = computation_building_blocks.Lambda(
          'value_in', input_ref.type_signature, output_tuple)
      return zip_apply_fn[output_placement](lam, value[0])
    for _, elem in elements_to_zip:
      py_typecheck.check_type(elem, computation_types.FederatedType)
      if elem.placement != output_placement:
        raise TypeError(
            'The elements of the named tuple to zip must be placed at {}.'
            .format(output_placement))
    # Pair each element's name with the building block behind its value.
    named_comps = [(name, value_impl.ValueImpl.get_comp(value[k]))
                   for k, (name, _) in enumerate(elements_to_zip)]
    tuple_to_zip = anonymous_tuple.AnonymousTuple(named_comps[:2])
    zipped = value_utils.zip_two_tuple(
        value_impl.to_value(tuple_to_zip, None, self._context_stack),
        self._context_stack)
    # `flatten_func` starts as the identity function on the first zipped
    # pair's member type.
    inputs = value_impl.to_value(
        computation_building_blocks.Reference(
            'inputs', zipped.type_signature.member), None, self._context_stack)
    flatten_func = value_impl.to_value(
        computation_building_blocks.Lambda(
            'inputs', zipped.type_signature.member,
            value_impl.ValueImpl.get_comp(inputs)), None, self._context_stack)
    # Fold each remaining element into the accumulated zip, two at a time.
    # After each fold, `flatten_func` is extended via
    # `value_utils.flatten_first_index`. Presumably this un-nests the
    # accumulated pair; confirm against that helper.
    for k in range(2, num_elements):
      zipped = value_utils.zip_two_tuple(
          value_impl.to_value(
              computation_building_blocks.Tuple(
                  [value_impl.ValueImpl.get_comp(zipped), named_comps[k]]),
              None, self._context_stack), self._context_stack)
      last_zipped = (named_comps[k][0], named_comps[k][1].type_signature.member)
      flatten_func = value_utils.flatten_first_index(flatten_func, last_zipped,
                                                     self._context_stack)
    return zip_apply_fn[output_placement](flatten_func, zipped)