Exemplo n.º 1
0
  def test_flatten_fn_with_names(self, n):
    """Checks `flatten_first_index` function types for a named n-tuple input.

    Builds a lambda whose result is a reference to its own parameter, a tuple
    of `n` int32 fields named '0' through str(n - 1). The function type
    produced by `value_utils.flatten_first_index` is then compared (as a
    string) with the expected type, once for an unnamed extra element and once
    for an extra element named 'new'.

    Args:
      n: Number of named int32 fields in the lambda's parameter tuple.
    """

    def make_named_fields():
      # A fresh list per call, so no list object is shared between the
      # separate type constructions below.
      return [(str(idx), tf.int32) for idx in range(n)]

    param_ref = computation_building_blocks.Reference(
        'test', make_named_fields())
    wrapped_lambda = computation_building_blocks.Lambda(
        'test', param_ref.type_signature, param_ref)

    # Case 1: the element being appended carries no name.
    extra_unnamed = (None, computation_types.to_type(tf.int32))
    expected_unnamed_fn_type = computation_types.FunctionType(
        computation_types.NamedTupleType(
            [param_ref.type_signature, extra_unnamed]),
        computation_types.to_type(make_named_fields() + [tf.int32]))
    flattened_unnamed = value_utils.flatten_first_index(
        value_impl.to_value(wrapped_lambda, None, _context_stack),
        extra_unnamed, _context_stack)
    self.assertEqual(
        str(flattened_unnamed.type_signature), str(expected_unnamed_fn_type))

    # Case 2: the element being appended is named 'new'.
    extra_named = ('new', tf.int32)
    nested_named_param_type = computation_types.NamedTupleType(
        [param_ref.type_signature, extra_named])
    flat_named_fields = make_named_fields() + [('new', tf.int32)]
    expected_named_fn_type = computation_types.FunctionType(
        nested_named_param_type, computation_types.to_type(flat_named_fields))
    flattened_named = value_utils.flatten_first_index(
        value_impl.to_value(wrapped_lambda, None, _context_stack),
        extra_named, _context_stack)
    self.assertEqual(
        str(flattened_named.type_signature), str(expected_named_fn_type))
Exemplo n.º 2
0
 def test_flatten_fn_comp_raises_typeerror(self):
   """Passing a raw building block to `flatten_first_index` raises TypeError.

   The lambda is handed over directly, without first being wrapped through
   `value_impl.to_value`, and the call is expected to fail with a message
   matching '(Expected).*(Value)'.
   """
   param_ref = computation_building_blocks.Reference('test', [tf.int32] * 5)
   unwrapped_lambda = computation_building_blocks.Lambda(
       'test', param_ref.type_signature, param_ref)
   extra_type = computation_types.NamedTupleType([tf.int32])
   expected_message = '(Expected).*(Value)'
   # NOTE(review): `assertRaisesRegexp` is the legacy unittest spelling of
   # `assertRaisesRegex`; kept unchanged here.
   with self.assertRaisesRegexp(TypeError, expected_message):
     value_utils.flatten_first_index(unwrapped_lambda, extra_type,
                                     _context_stack)
Exemplo n.º 3
0
 def test_flatten_fn(self, n):
   """Checks the `flatten_first_index` function type for an unnamed n-tuple.

   The lambda under test takes a tuple of `n` unnamed int32 elements and
   returns a reference to that parameter. After flattening with one extra
   unnamed int32, the expected function type maps the pair
   <n-tuple, int32> to a flat tuple of n + 1 int32 elements.

   Args:
     n: Number of unnamed int32 elements in the lambda's parameter tuple.
   """
   param_ref = computation_building_blocks.Reference('test', [tf.int32] * n)
   wrapped_lambda = computation_building_blocks.Lambda(
       'test', param_ref.type_signature, param_ref)
   extra_element = (None, computation_types.to_type(tf.int32))

   # Expected signature of the flattened function.
   nested_param_type = computation_types.NamedTupleType(
       [param_ref.type_signature, extra_element])
   flat_result_type = computation_types.to_type([tf.int32] * (n + 1))
   expected_fn_type = computation_types.FunctionType(
       nested_param_type, flat_result_type)

   lambda_value = value_impl.to_value(wrapped_lambda, None, _context_stack)
   flattened = value_utils.flatten_first_index(
       lambda_value, extra_element, _context_stack)
   actual_signature = str(flattened.type_signature)
   self.assertEqual(actual_signature, str(expected_fn_type))
Exemplo n.º 4
0
  def federated_zip(self, value):
    """Implements `federated_zip` as defined in `api/intrinsics.py`.

    The tuple is zipped two components at a time: the first two components are
    zipped directly, and every further component is zipped with the running
    result. A flattening function is built up alongside and finally applied
    (via `federated_map` or `federated_apply`, depending on placement) to the
    nested result.

    Args:
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
      ValueError: If `value` is an empty tuple.
    """
    # TODO(b/113112108): Extend this to accept *args.

    # TODO(b/113112108): We use the iterate/unwrap approach below because
    # our type system is not powerful enough to express the concept of
    # "an operation that takes tuples of T of arbitrary length", and therefore
    # the intrinsic federated_zip must only take a fixed number of arguments,
    # here fixed at 2. There are other potential approaches to getting around
    # this problem (e.g. having the operator act on sequences and thereby
    # sidestepping the issue) which we may want to explore.
    value = value_impl.to_value(value, None, self._context_stack)
    py_typecheck.check_type(value, value_base.Value)
    py_typecheck.check_type(value.type_signature,
                            computation_types.NamedTupleType)
    elements_to_zip = anonymous_tuple.to_elements(value.type_signature)
    num_elements = len(elements_to_zip)
    # Reject empty tuples before anything below indexes the first element;
    # otherwise that lookup would fail with an IndexError and this check
    # would never be reached.
    if num_elements == 0:
      raise ValueError('federated_zip is only supported on nonempty tuples.')
    py_typecheck.check_type(elements_to_zip[0][1],
                            computation_types.FederatedType)
    # The placement of the first component determines the required placement
    # of every other component and which operator applies the final function.
    output_placement = elements_to_zip[0][1].placement
    zip_apply_fn = {
        placements.CLIENTS: self.federated_map,
        placements.SERVER: self.federated_apply
    }
    if output_placement not in zip_apply_fn:
      raise TypeError(
          'federated_zip only supports components with CLIENTS or '
          'SERVER placement, [{}] is unsupported'.format(output_placement))
    if num_elements == 1:
      # Nothing to zip: wrap the single member into a one-element tuple
      # (keeping its name) by mapping/applying a lambda over it.
      input_ref = computation_building_blocks.Reference(
          'value_in', elements_to_zip[0][1].member)
      output_tuple = computation_building_blocks.Tuple([(elements_to_zip[0][0],
                                                         input_ref)])
      lam = computation_building_blocks.Lambda(
          'value_in', input_ref.type_signature, output_tuple)
      return zip_apply_fn[output_placement](lam, value[0])
    # Every component must be federated and share the first one's placement.
    for _, elem in elements_to_zip:
      py_typecheck.check_type(elem, computation_types.FederatedType)
      if elem.placement is not output_placement:
        raise TypeError(
            'The elements of the named tuple to zip must be placed at {}.'
            .format(output_placement))
    # Pair each component's name with its underlying computation.
    named_comps = [(elements_to_zip[k][0],
                    value_impl.ValueImpl.get_comp(value[k]))
                   for k in range(len(value))]
    # Seed the iteration by zipping the first two components.
    tuple_to_zip = anonymous_tuple.AnonymousTuple(
        [named_comps[0], named_comps[1]])
    zipped = value_utils.zip_two_tuple(
        value_impl.to_value(tuple_to_zip, None, self._context_stack),
        self._context_stack)
    # Start with a lambda over the two-element member type that returns its
    # parameter unchanged; it is extended once per additional component.
    inputs = value_impl.to_value(
        computation_building_blocks.Reference(
            'inputs', zipped.type_signature.member), None, self._context_stack)
    flatten_func = value_impl.to_value(
        computation_building_blocks.Lambda(
            'inputs', zipped.type_signature.member,
            value_impl.ValueImpl.get_comp(inputs)), None, self._context_stack)
    for k in range(2, num_elements):
      # Zip the running result with the next component, producing a nested
      # pair, and extend the flattening function to account for it.
      zipped = value_utils.zip_two_tuple(
          value_impl.to_value(
              computation_building_blocks.Tuple(
                  [value_impl.ValueImpl.get_comp(zipped), named_comps[k]]),
              None, self._context_stack), self._context_stack)
      last_zipped = (named_comps[k][0], named_comps[k][1].type_signature.member)
      flatten_func = value_utils.flatten_first_index(flatten_func, last_zipped,
                                                     self._context_stack)
    return zip_apply_fn[output_placement](flatten_func, zipped)