예제 #1
0
    def test_with_block(self):
        """Runs a lambda whose body is a block computing `f(f(x))`.

        The lambda takes a named tuple `<f, x>`, binds both fields as block
        locals and applies `f` twice to `x`. With `f = add_one` and `x = 10`
        the computed result is expected to be 12.
        """
        executor = lambda_executor.LambdaExecutor(
            eager_executor.EagerExecutor())
        run = asyncio.get_event_loop().run_until_complete

        # Build `a -> (let f=a.f, x=a.x in f(f(x)))`.
        int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
        param = computation_building_blocks.Reference(
            'a',
            computation_types.NamedTupleType([('f', int_to_int),
                                              ('x', tf.int32)]))
        block_locals = [
            ('f', computation_building_blocks.Selection(param, name='f')),
            ('x', computation_building_blocks.Selection(param, name='x')),
        ]
        inner_call = computation_building_blocks.Call(
            computation_building_blocks.Reference('f', int_to_int),
            computation_building_blocks.Reference('x', tf.int32))
        outer_call = computation_building_blocks.Call(
            computation_building_blocks.Reference('f', int_to_int),
            inner_call)
        body = computation_building_blocks.Block(block_locals, outer_call)
        comp = computation_building_blocks.Lambda(param.name,
                                                  param.type_signature, body)

        @computations.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        # Embed the lambda and its argument in the executor, then invoke it.
        lambda_val = run(
            executor.create_value(comp.proto, comp.type_signature))
        add_one_val = run(executor.create_value(add_one))
        ten_val = run(executor.create_value(10, tf.int32))
        arg_val = run(
            executor.create_tuple(
                anonymous_tuple.AnonymousTuple([('f', add_one_val),
                                                ('x', ten_val)])))
        call_val = run(executor.create_call(lambda_val, arg_val))
        result = run(call_val.compute())
        self.assertEqual(result.numpy(), 12)
 def test_basic_functionality_of_lambda_class(self):
     """Checks the type, representations and proto of a `Lambda` block.

     The lambda under test is `arg -> arg.f(arg.f(arg.x))` where `arg` is a
     named tuple with a function `f` and an integer `x`.
     """
     param_name = 'arg'
     param_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                   ('x', tf.int32)]
     param_ref = computation_building_blocks.Reference(param_name, param_type)
     sel_f = computation_building_blocks.Selection(param_ref, name='f')
     sel_x = computation_building_blocks.Selection(param_ref, name='x')
     body = computation_building_blocks.Call(
         sel_f, computation_building_blocks.Call(sel_f, sel_x))
     lam = computation_building_blocks.Lambda(param_name, param_type, body)

     # Type signature and accessors.
     self.assertEqual(str(lam.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(lam.parameter_name, param_name)
     self.assertEqual(str(lam.parameter_type),
                      '<f=(int32 -> int32),x=int32>')
     self.assertEqual(lam.result.tff_repr, 'arg.f(arg.f(arg.x))')

     # `repr` embeds the parameter type once per `Reference` to `arg`.
     arg_type_repr = (
         'NamedTupleType(['
         '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
         '(\'x\', TensorType(tf.int32))])')
     expected_repr = (
         'Lambda(\'arg\', {0}, '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
             arg_type_repr))
     self.assertEqual(repr(lam), expected_repr)
     self.assertEqual(lam.tff_repr, '(arg -> arg.f(arg.f(arg.x)))')

     # Proto form; `lambda` is a Python keyword, hence the `getattr`.
     lam_proto = lam.proto
     self.assertEqual(type_serialization.deserialize_type(lam_proto.type),
                      lam.type_signature)
     self.assertEqual(lam_proto.WhichOneof('computation'), 'lambda')
     self.assertEqual(getattr(lam_proto, 'lambda').parameter_name,
                      param_name)
     self.assertEqual(str(getattr(lam_proto, 'lambda').result),
                      str(lam.result.proto))
     self._serialize_deserialize_roundtrip_test(lam)
예제 #3
0
def construct_binary_operator_with_upcast(type_signature, operator):
  """Constructs lambda upcasting its argument and applying `operator`.

  The concept of upcasting is explained further in the docstring for
  `apply_binary_operator_with_upcast`.

  Notice that since we are constructing a function here, e.g. for the body
  of an intrinsic, the function we are constructing must be reducible to
  TensorFlow. Therefore `type_signature` can only have named tuple or tensor
  type elements; that is, we cannot handle federated types here in a generic
  way.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`,
      with two elements, both of the same type or the second able to be upcast
      to the first, as explained in `apply_binary_operator_with_upcast`, and
      both containing only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `computation_building_blocks.Lambda` encapsulating a function which
    upcasts the second element of its argument and applies the binary
    operator.

  Raises:
    TypeError: If, while upcasting, a type is encountered that is neither a
      named tuple type nor a tensor type.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  ref_to_arg = computation_building_blocks.Reference('binary_operator_arg',
                                                     type_signature)

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if isinstance(type_spec, computation_types.NamedTupleType):
      elems = anonymous_tuple.to_elements(type_spec)
      packed_elems = [(elem_name, _pack_into_type(to_pack, elem_type))
                      for elem_name, elem_type in elems]
      return computation_building_blocks.Tuple(packed_elems)
    if isinstance(type_spec, computation_types.TensorType):
      expand_fn = computation_constructing_utils.construct_tensorflow_to_broadcast_scalar(
          to_pack.type_signature.dtype, type_spec.shape)
      return computation_building_blocks.Call(expand_fn, to_pack)
    # Fail loudly instead of implicitly returning `None`, which would only
    # surface later as a confusing error when the result is used.
    raise TypeError(
        'Cannot pack a value into type {}; only named tuple and tensor types '
        'are supported.'.format(type_spec))

  y_ref = computation_building_blocks.Selection(ref_to_arg, index=1)
  first_arg = computation_building_blocks.Selection(ref_to_arg, index=0)

  # Only broadcast the second element when its type differs from the first.
  if type_utils.are_equivalent_types(first_arg.type_signature,
                                     y_ref.type_signature):
    second_arg = y_ref
  else:
    second_arg = _pack_into_type(y_ref, first_arg.type_signature)

  fn = computation_constructing_utils.construct_tensorflow_binary_operator(
      first_arg.type_signature, operator)
  packed = computation_building_blocks.Tuple([first_arg, second_arg])
  operated = computation_building_blocks.Call(fn, packed)
  lambda_encapsulating_op = computation_building_blocks.Lambda(
      ref_to_arg.name, ref_to_arg.type_signature, operated)
  return lambda_encapsulating_op
def _create_chain_zipped_values(value):
    r"""Creates a chain of called federated zip with two values.

                  Block--------
                 /             \
    [value=Tuple]               Call
           |                   /    \
           [Comp1,    Intrinsic      Tuple
            Comp2,                   |
            ...]                     [Call,  Sel(n)]
                                     /    \        \
                            Intrinsic      Tuple    Ref(value)
                                           |
                                           [Sel(0),       Sel(1)]
                                                  \             \
                                                   Ref(value)    Ref(value)

    NOTE: This function is intended to be used in conjunction with
    `_create_fn_to_append_chain_zipped_values` and will drop the tuple names.
    The names will be added back to the resulting computation when the zipped
    values are mapped to a function that flattens the chain. This nested
    zip -> flatten structure must be used since length of a named tuple type
    in the TFF type system is an element of the type proper. That is, a named
    tuple type of length 2 is a different type than a named tuple type of
    length 3, they are not simply items with the same type and different
    values, as would be the case if you were thinking of these as Python
    `list`s. It may be better to think of named tuple types in TFF as more
    like `struct`s.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` with a
        `type_signature` of type `computation_types.NamedTupleType` containing
        at least two elements.

    Returns:
      A `computation_building_blocks.Block` that binds `value` to a local
      named 'value' and whose result is the chain of zip calls.

    Raises:
      TypeError: If any of the types do not match.
      ValueError: If `value` does not contain at least two elements.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length < 2:
        # Report the element count; the message reads "received N elements".
        raise ValueError(
            'Expected a value with at least two elements, received {} elements.'
            .format(length))
    ref = computation_building_blocks.Reference('value', value.type_signature)
    symbols = ((ref.name, value), )
    sel_0 = computation_building_blocks.Selection(ref, index=0)
    result = sel_0
    # Fold left: zip the accumulated result with each subsequent element.
    for i in range(1, length):
        sel = computation_building_blocks.Selection(ref, index=i)
        values = computation_building_blocks.Tuple((result, sel))
        result = _create_zip_two_values(values)
    return computation_building_blocks.Block(symbols, result)
예제 #5
0
def _create_chain_zipped_values(value):
    r"""Creates a chain of called federated zip with two values.

                  Block--------
                 /             \
    [value=Tuple]               Call
           |                   /    \
           [Comp1,    Intrinsic      Tuple
            Comp2,                   |
            ...]                     [Call,  Sel(n)]
                                     /    \        \
                            Intrinsic      Tuple    Ref(value)
                                           |
                                           [Sel(0),       Sel(1)]
                                                  \             \
                                                   Ref(value)    Ref(value)

    NOTE: This function is intended to be used in conjunction with
    `_create_fn_to_append_chain_zipped_values` and will drop the tuple names.
    The names will be added back to the resulting computation when the zipped
    values are mapped to a function that flattens the chain.

    NOTE(review): this version passes each selected element to `Tuple`
    together with its original name; whether those names survive
    `_create_zip_two_values` is not visible from this function -- confirm
    against the statement above.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` with a
        `type_signature` of type `computation_types.NamedTupleType` containing
        at least two elements.

    Returns:
      A `computation_building_blocks.Block` that binds `value` to a local
      named 'value' and whose result is the chain of zip calls.

    Raises:
      TypeError: If any of the types do not match.
      ValueError: If `value` does not contain at least two elements.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length < 2:
        # Report the element count; the message reads "received N elements".
        raise ValueError(
            'Expected a value with at least two elements, received {} elements.'
            .format(length))
    first_name, _ = named_type_signatures[0]
    ref = computation_building_blocks.Reference('value', value.type_signature)
    symbols = ((ref.name, value), )
    sel_0 = computation_building_blocks.Selection(ref, index=0)
    # The first operand is a (name, computation) pair; after the first zip the
    # accumulated result is the building block returned by the zip helper.
    result = (first_name, sel_0)
    for i in range(1, length):
        name, _ = named_type_signatures[i]
        sel = computation_building_blocks.Selection(ref, index=i)
        values = computation_building_blocks.Tuple((result, (name, sel)))
        result = _create_zip_two_values(values)
    return computation_building_blocks.Block(symbols, result)
예제 #6
0
def create_computation_appending(comp1, comp2):
    r"""Returns a block appending `comp2` to `comp1`.

                  Block
                 /     \
    [comps=Tuple]       Tuple
           |            |
      [Comp, Comp]      [Sel(0), ...,  Sel(0),   Sel(1)]
                               \             \         \
                                Sel(0)        Sel(n)    Ref(comps)
                                      \             \
                                       Ref(comps)    Ref(comps)

    Both computations are bound once as the local `comps`; the result tuple
    re-selects every element of `comp1` (keeping its name) and then adds
    `comp2` as the final element.

    Args:
      comp1: A `computation_building_blocks.ComputationBuildingBlock` with a
        `type_signature` of type `computation_types.NamedTupleType`.
      comp2: A `computation_building_blocks.ComputationBuildingBlock` or a
        named computation (a tuple pair of name, computation) representing a
        single element of an `anonymous_tuple.AnonymousTuple`.

    Returns:
      A `computation_building_blocks.Block`.

    Raises:
      TypeError: If any of the types do not match.
    """
    block_cls = computation_building_blocks.ComputationBuildingBlock
    py_typecheck.check_type(comp1, block_cls)
    if isinstance(comp2, block_cls):
        appended_name = None
    elif py_typecheck.is_name_value_pair(
            comp2, name_required=False, value_type=block_cls):
        appended_name, comp2 = comp2
    else:
        raise TypeError('Unexpected tuple element: {}.'.format(comp2))

    comps = computation_building_blocks.Tuple((comp1, comp2))
    comps_ref = computation_building_blocks.Reference('comps',
                                                      comps.type_signature)
    head = computation_building_blocks.Selection(comps_ref, index=0)
    # Re-select each element of `comp1` under its existing name.
    elements = [
        (elem_name,
         computation_building_blocks.Selection(head, index=position))
        for position, (elem_name, _) in enumerate(
            anonymous_tuple.to_elements(comp1.type_signature))
    ]
    elements.append(
        (appended_name,
         computation_building_blocks.Selection(comps_ref, index=1)))
    appended = computation_building_blocks.Tuple(elements)
    return computation_building_blocks.Block(((comps_ref.name, comps), ),
                                             appended)
예제 #7
0
    def _transform_functional_args(comps):
        r"""Transforms the functional computations `comps`.

        Given a computation containing `n` called intrinsics with `m`
        arguments, this function constructs the following computation from the
        functional arguments of the called intrinsic:

                        Block
                       /     \
             [fn=Tuple]       Lambda(arg)
                 |                       \
        [Comp(f1), Comp(f2), ...]         Tuple
                                          |
                                     [Call,                  Call, ...]
                                     /    \                 /    \
                               Sel(0)      Sel(0)     Sel(1)      Sel(1)
                              /           /          /           /
                       Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

        with one `computation_building_blocks.Call` for each `n`. This
        computation represents one of `m` arguments that should be passed to
        the call of the transformed computation.

        Args:
          comps: a Python list of computations.

        Returns:
          A `computation_building_blocks.Block`.
        """
        functions = computation_building_blocks.Tuple(comps)
        # Fresh names are drawn from the enclosing `name_generator`: first for
        # the tuple of functions, then for the lambda parameter.
        fn_name = six.next(name_generator)
        fn_ref = computation_building_blocks.Reference(
            fn_name, functions.type_signature)
        param_name = six.next(name_generator)
        param_type = [comp.type_signature.parameter for comp in comps]
        param_ref = computation_building_blocks.Reference(
            param_name, param_type)
        # The i-th call applies the i-th function to the i-th argument.
        calls = computation_building_blocks.Tuple([
            computation_building_blocks.Call(
                computation_building_blocks.Selection(fn_ref, index=position),
                computation_building_blocks.Selection(param_ref,
                                                      index=position))
            for position in range(len(comps))
        ])
        lam = computation_building_blocks.Lambda(param_ref.name,
                                                 param_ref.type_signature,
                                                 calls)
        return computation_building_blocks.Block(
            ((fn_ref.name, functions), ), lam)
예제 #8
0
 def __setattr__(self, name, value):
     """Replaces element `name` of the wrapped named tuple with `value`.

     A new `Tuple` building block is constructed in which the element called
     `name` is the converted `value` and every other element is a selection
     from the current computation; it is then stored as `_comp`.

     Args:
       name: String name of the element to assign.
       value: The new value; converted via `to_value` against the element's
         existing type.

     Raises:
       TypeError: If the wrapped computation is not of a named tuple type, or
         if `value` cannot be converted to the element's type.
       AttributeError: If `name` is not found in `dir()` of the type
         signature.
     """
     py_typecheck.check_type(name, six.string_types)
     type_spec = self._comp.type_signature
     if not isinstance(type_spec, computation_types.NamedTupleType):
         raise TypeError(
             'Operator setattr() is only supported for named tuples, but the '
             'object on which it has been invoked is of type {}.'.format(
                 str(type_spec)))
     if name not in dir(type_spec):
         raise AttributeError(
             'There is no such attribute as \'{}\' in this tuple. '
             'TFF does not allow for assigning to a nonexistent attribute. '
             'If you want to assign to \'{}\', you must create a new named tuple '
             'containing this attribute.'.format(name, name))
     rebuilt = []
     for elem_name, elem_type in anonymous_tuple.to_elements(type_spec):
         if elem_name != name:
             # Untouched elements are carried over by selection.
             rebuilt.append(
                 (elem_name,
                  computation_building_blocks.Selection(self._comp,
                                                        name=elem_name)))
             continue
         try:
             value = to_value(value, elem_type, self._context_stack)
         except TypeError:
             raise TypeError(
                 'Setattr has attempted to set element {} of type {} '
                 'with incompatible item {}.'.format(elem_name, elem_type,
                                                     value))
         rebuilt.append((elem_name, ValueImpl.get_comp(value)))
     # Go through the parent class to store the new computation, since this
     # method itself overrides attribute assignment.
     super(ValueImpl, self).__setattr__(
         '_comp', computation_building_blocks.Tuple(rebuilt))
 def test_returns_string_for_comp_with_right_overhang(self):
     """Checks the three string representations of `(a -> <...>[0])(data)`.

     The called lambda selects index 0 from a five-element tuple, which makes
     the last line of the structural representation extend to the right of
     the rest of the drawing.
     """
     param_ref = computation_building_blocks.Reference('a', tf.int32)
     datum = computation_building_blocks.Data('data', tf.int32)
     wide_tuple = computation_building_blocks.Tuple(
         [param_ref, datum, datum, datum, datum])
     first = computation_building_blocks.Selection(wide_tuple, index=0)
     lam = computation_building_blocks.Lambda(param_ref.name,
                                              param_ref.type_signature,
                                              first)
     comp = computation_building_blocks.Call(lam, datum)

     # pyformat: disable
     expected_formatted = (
         '(a -> <\n'
         '  a,\n'
         '  data,\n'
         '  data,\n'
         '  data,\n'
         '  data\n'
         '>[0])(data)')
     expected_structural = (
         '          Call\n'
         '         /    \\\n'
         'Lambda(a)      data\n'
         '|\n'
         'Sel(0)\n'
         '|\n'
         'Tuple\n'
         '|\n'
         '[Ref(a), data, data, data, data]')
     # pyformat: enable

     compact_string = comp.compact_representation()
     self.assertEqual(compact_string,
                      '(a -> <a,data,data,data,data>[0])(data)')
     formatted_string = comp.formatted_representation()
     self.assertEqual(formatted_string, expected_formatted)
     structural_string = comp.structural_representation()
     self.assertEqual(structural_string, expected_structural)
예제 #10
0
 def test_propogates_dependence_up_through_selection(self):
   """A selection from a matching intrinsic is reported as a consumer."""
   intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                     [tf.int32])
   sel = computation_building_blocks.Selection(intrinsic, index=0)
   consumers = tree_analysis.extract_nodes_consuming(
       sel, dummy_intrinsic_predicate)
   self.assertIn(sel, consumers)
예제 #11
0
    def test_replace_chained_federated_maps_does_not_replace_unchained_federated_maps(
            self):
        """Two maps separated by a tuple/selection pair are left untouched.

        The count of `federated_map` intrinsics and the computed result are
        expected to be the same before and after the transformation.
        """
        clients_int = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
        arg_ref = computation_building_blocks.Reference('arg', clients_int)
        first_fn = _create_lambda_to_add_one(arg_ref.type_signature.member)
        first_map = _create_call_to_federated_map(first_fn, arg_ref)
        # Route the first map through a tuple and a selection so that the
        # second map is not applied to it directly.
        unchained_arg = computation_building_blocks.Selection(
            computation_building_blocks.Tuple([first_map]), index=0)
        second_fn = _create_lambda_to_add_one(
            first_map.function.type_signature.result.member)
        second_map = _create_call_to_federated_map(second_fn, unchained_arg)
        comp = computation_building_blocks.Lambda(
            arg_ref.name, arg_ref.type_signature, second_map)
        uri = intrinsic_defs.FEDERATED_MAP.uri

        self.assertEqual(_get_number_of_intrinsics(comp, uri), 2)
        self.assertEqual(_to_comp(comp)([1]), [3])

        transformed_comp = transformations.replace_chained_federated_maps_with_federated_map(
            comp)

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 2)
        self.assertEqual(_to_comp(transformed_comp)([1]), [3])
def construct_federated_getattr_comp(comp, name):
  """Builds the lambda `x -> x.<name>` over the member type of `comp`.

  The returned `computation_building_blocks.Lambda` takes an argument of type
  `comp.type_signature.member`, an instance of
  `computation_types.NamedTupleType`, and selects the element called `name`
  from it. It is meant to be used with `federated_apply` to implement
  `__getattr__` on federated values.

  Args:
    comp: Instance of `ValueImpl` or
      `computation_building_blocks.ComputationBuildingBlock` with type signature
      `computation_types.FederatedType` whose `member` attribute is of type
      `computation_types.NamedTupleType`.
    name: String name of attribute to grab.

  Returns:
    Instance of `computation_building_blocks.Lambda` which grabs attribute
      according to `name` of its argument.

  Raises:
    ValueError: If the member type of `comp` has no element called `name`.
  """
  py_typecheck.check_type(comp.type_signature, computation_types.FederatedType)
  member_type = comp.type_signature.member
  py_typecheck.check_type(member_type, computation_types.NamedTupleType)
  known_names = [
      elem_name for elem_name, _ in anonymous_tuple.to_elements(member_type)
  ]
  if name not in known_names:
    raise ValueError('The federated value {} has no element of name {}'.format(
        comp, name))
  param = computation_building_blocks.Reference('x', member_type)
  selected = computation_building_blocks.Selection(param, name=name)
  return computation_building_blocks.Lambda('x', param.type_signature,
                                            selected)
예제 #13
0
 def __getattr__(self, name):
     """Returns element `name` of the wrapped value as a new `ValueImpl`.

     For a federated value whose member is a named tuple the lookup is
     delegated to `construct_federated_getattr_call`. For a plain named tuple
     the element is taken directly from a `Tuple` building block when
     possible, and otherwise expressed as a `Selection`.

     Args:
       name: String name of the element to fetch.

     Raises:
       TypeError: If the wrapped computation is neither a named tuple nor a
         federated named tuple.
       AttributeError: If `name` is not found in `dir()` of the type
         signature.
     """
     py_typecheck.check_type(name, six.string_types)
     type_spec = self._comp.type_signature
     is_federated_tuple = (
         isinstance(type_spec, computation_types.FederatedType)
         and isinstance(type_spec.member, computation_types.NamedTupleType))
     if is_federated_tuple:
         federated_call = (
             computation_constructing_utils.construct_federated_getattr_call(
                 self._comp, name))
         return ValueImpl(federated_call, self._context_stack)
     if not isinstance(type_spec, computation_types.NamedTupleType):
         raise TypeError(
             'Operator getattr() is only supported for named tuples, but the '
             'object on which it has been invoked is of type {}.'.format(
                 str(type_spec)))
     if name not in dir(type_spec):
         raise AttributeError(
             'There is no such attribute as \'{}\' in this tuple.'.format(
                 name))
     if isinstance(self._comp, computation_building_blocks.Tuple):
         # A literal tuple already holds the element; no selection needed.
         selected = getattr(self._comp, name)
     else:
         selected = computation_building_blocks.Selection(self._comp,
                                                          name=name)
     return ValueImpl(selected, self._context_stack)
예제 #14
0
 def __getitem__(self, key):
   """Returns the element(s) of the wrapped value selected by `key`.

   For a federated value whose member is a named tuple the lookup is delegated
   to `construct_federated_getitem_call`. For a plain named tuple an `int`
   key yields a single element and a `slice` key yields a new value built
   from the individually indexed elements.

   Args:
     key: An `int` index or a `slice`.

   Raises:
     TypeError: If the wrapped computation is neither a named tuple nor a
       federated named tuple.
     IndexError: If an `int` key is negative or past the end, or if a `slice`
       key selects no elements.
   """
   py_typecheck.check_type(key, (int, slice))
   type_spec = self._comp.type_signature
   is_federated_tuple = (
       isinstance(type_spec, computation_types.FederatedType)
       and isinstance(type_spec.member, computation_types.NamedTupleType))
   if is_federated_tuple:
     federated_call = (
         computation_constructing_utils.construct_federated_getitem_call(
             self._comp, key))
     return ValueImpl(federated_call, self._context_stack)
   if not isinstance(type_spec, computation_types.NamedTupleType):
     raise TypeError(
         'Operator getitem() is only supported for named tuples, but the '
         'object on which it has been invoked is of type {}.'.format(
             str(type_spec)))
   num_elems = len(type_spec)
   if isinstance(key, int):
     if not 0 <= key < num_elems:
       raise IndexError(
           'The index of the selected element {} is out of range.'.format(key))
     if isinstance(self._comp, computation_building_blocks.Tuple):
       # A literal tuple already holds the element; no selection needed.
       selected = self._comp[key]
     else:
       selected = computation_building_blocks.Selection(self._comp, index=key)
     return ValueImpl(selected, self._context_stack)
   if isinstance(key, slice):
     positions = range(*key.indices(num_elems))
     if not positions:
       raise IndexError('Attempted to slice 0 elements, which is not '
                        'currently supported.')
     return to_value([self[k] for k in positions], None, self._context_stack)
 def test_basic_functionality_of_block_class(self):
     """Checks the type, representations and proto of a `Block`.

     The block under test is `(let x=arg,y=x[0] in y)` where `arg` is a pair
     of integers.
     """
     x = computation_building_blocks.Block([
         ('x',
          computation_building_blocks.Reference('arg',
                                                (tf.int32, tf.int32))),
         ('y',
          computation_building_blocks.Selection(
              computation_building_blocks.Reference('x',
                                                    (tf.int32, tf.int32)),
              index=0))
     ], computation_building_blocks.Reference('y', tf.int32))
     self.assertEqual(str(x.type_signature), 'int32')
     self.assertEqual([(k, v.tff_repr) for k, v in x.locals],
                      [('x', 'arg'), ('y', 'x[0]')])
     self.assertEqual(x.result.tff_repr, 'y')
     self.assertEqual(
         repr(x), 'Block([(\'x\', Reference(\'arg\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
         '(\'y\', Selection(Reference(\'x\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
         'index=0))], '
         'Reference(\'y\', TensorType(tf.int32)))')
     self.assertEqual(x.tff_repr, '(let x=arg,y=x[0] in y)')
     x_proto = x.proto
     self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                      x.type_signature)
     self.assertEqual(x_proto.WhichOneof('computation'), 'block')
     self.assertEqual(str(x_proto.block.result), str(x.result.proto))
     # Each serialized local must match the corresponding block local.
     for idx, loc_proto in enumerate(x_proto.block.local):
         loc_name, loc_value = x.locals[idx]
         self.assertEqual(loc_proto.name, loc_name)
         self.assertEqual(str(loc_proto.value), str(loc_value.proto))
     # The round trip concerns the whole block, so it runs once, outside the
     # per-local loop.
     self._serialize_deserialize_roundtrip_test(x)
예제 #16
0
        def _create_block_to_calls(call_names, comps):
            r"""Constructs a transformed block computation from `comps`.

            Given the "original" computation containing `n` called intrinsics
            with `m` arguments, this function constructs the following
            computation:

                           Block
                          /     \
                  fn=Tuple       Lambda(arg)
                     |                      \
            [Comp(f1), Comp(f2), ...]        Tuple
                                             |
                                        [Call,                  Call, ...]
                                        /    \                 /    \
                                  Sel(0)      Sel(0)     Sel(1)      Sel(1)
                                 /           /          /           /
                          Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

            with one `computation_building_blocks.Call` for each `n`. This
            computation represents one of `m` arguments that should be passed
            to the call of the "transformed" computation.

            Args:
              call_names: a Python list of names.
              comps: a Python list of computations.

            Returns:
              A `computation_building_blocks.Block`.
            """
            functions = computation_building_blocks.Tuple(
                zip(call_names, comps))
            fn_ref = computation_building_blocks.Reference(
                'fn', functions.type_signature)
            param_type = [comp.type_signature.parameter for comp in comps]
            param_ref = computation_building_blocks.Reference(
                'arg', param_type)
            # The i-th named call applies the i-th function to the i-th
            # argument.
            named_calls = [
                (call_name,
                 computation_building_blocks.Call(
                     computation_building_blocks.Selection(fn_ref,
                                                           index=position),
                     computation_building_blocks.Selection(param_ref,
                                                           index=position)))
                for position, call_name in enumerate(call_names)
            ]
            body = computation_building_blocks.Tuple(named_calls)
            lam = computation_building_blocks.Lambda(
                param_ref.name, param_ref.type_signature, body)
            return computation_building_blocks.Block([('fn', functions)], lam)
예제 #17
0
 def _traverse_selection(comp, transform, context_tree, identifier_seq):
   """Helper function holding traversal logic for selection nodes.

   Advances `identifier_seq` once, transforms the selection's source, rebuilds
   the selection around the transformed source and finally applies
   `transform` to the rebuilt node.
   """
   six.next(identifier_seq)  # The identifier itself is not used here.
   new_source = _transform_postorder_with_symbol_bindings_switch(
       comp.source, transform, context_tree, identifier_seq)
   rebuilt = computation_building_blocks.Selection(new_source, comp.name,
                                                   comp.index)
   return transform(rebuilt, context_tree)
예제 #18
0
def construct_federated_getitem_comp(comp, key):
    """Function to construct computation for `federated_apply` of `__getitem__`.

    Constructs a `computation_building_blocks.ComputationBuildingBlock`
    which selects `key` from its argument, of type
    `comp.type_signature.member`, of type `computation_types.NamedTupleType`.

    Args:
      comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
        with type signature `computation_types.FederatedType` whose `member`
        attribute is of type `computation_types.NamedTupleType`.
      key: Instance of `int` or `slice`, key used to grab elements from the
        member of `comp`. An `int` selects a single element by index; a
        `slice` produces a tuple of the selected elements which keeps their
        names.

    Returns:
      Instance of `computation_building_blocks.Lambda` which grabs slice
        according to `key` of its argument.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(comp.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(comp.type_signature.member,
                            computation_types.NamedTupleType)
    py_typecheck.check_type(key, (int, slice))
    member_type = comp.type_signature.member
    param = computation_building_blocks.Reference('x', member_type)
    if isinstance(key, int):
        selected = computation_building_blocks.Selection(param, index=key)
    else:
        member_elems = anonymous_tuple.to_elements(member_type)
        positions = range(*key.indices(len(member_elems)))
        # Pair each selected element with its original name.
        selected = computation_building_blocks.Tuple([
            (member_elems[position][0],
             computation_building_blocks.Selection(param, index=position))
            for position in positions
        ])
    return computation_building_blocks.Lambda('x', param.type_signature,
                                              selected)
예제 #19
0
 def _traverse_selection(comp, transform, context_tree, identifier_seq):
     """Helper function holding traversal logic for selection nodes.

     Advances `identifier_seq` once, transforms the selection's source and
     rebuilds the selection only when the source was modified, then applies
     `transform` to the node.

     Returns:
       A pair of the resulting computation and a flag that is truthy when
       either the source or the node itself was modified.
     """
     six.next(identifier_seq)  # The identifier itself is not used here.
     new_source, source_changed = (
         _transform_postorder_with_symbol_bindings_switch(
             comp.source, transform, context_tree, identifier_seq))
     if source_changed:
         comp = computation_building_blocks.Selection(
             new_source, comp.name, comp.index)
     comp, comp_changed = transform(comp, context_tree)
     return comp, comp_changed or source_changed
예제 #20
0
def create_federated_unzip(value):
    r"""Creates a tuple of called federated maps or applies.

                Block
               /     \
  [value=Comp]        Tuple
                      |
                      [Call,                        Call, ...]
                      /    \                       /    \
             Intrinsic      Tuple         Intrinsic      Tuple
                            |                            |
                [Lambda(arg), Ref(value)]    [Lambda(arg), Ref(value)]
                            \                            \
                             Sel(0)                       Sel(1)
                                   \                            \
                                    Ref(arg)                     Ref(arg)

  This function returns a tuple of federated values given a `value` with a
  federated tuple type signature. Each element of the result is built with
  `create_federated_map_or_apply` from a lambda that selects one element of
  the tuple by index; element names are carried over to the result.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` whose
      `type_signature` has a `member` that is a
      `computation_types.NamedTupleType` containing at least one element.

  Returns:
    A `computation_building_blocks.Block` that binds `value` to the name
    'value' and whose result is the tuple described above.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain any elements.
  """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(
        value.type_signature.member)
    if not named_type_signatures:
        raise ValueError(
            'federated_unzip is only supported on non-empty tuples.')
    value_ref = computation_building_blocks.Reference('value',
                                                      value.type_signature)
    # The same parameter reference is reused by every per-element lambda.
    fn_ref = computation_building_blocks.Reference('arg',
                                                   named_type_signatures)
    elements = []
    for index, (name, _) in enumerate(named_type_signatures):
        sel = computation_building_blocks.Selection(fn_ref, index=index)
        fn = computation_building_blocks.Lambda(fn_ref.name,
                                                fn_ref.type_signature, sel)
        intrinsic = create_federated_map_or_apply(fn, value_ref)
        elements.append((name, intrinsic))
    result = computation_building_blocks.Tuple(elements)
    # Bind `value` once in the enclosing block so every call refers to it by
    # name.
    symbols = ((value_ref.name, value), )
    return computation_building_blocks.Block(symbols, result)
 def test_returns_string_for_selection_with_index(self):
     """Checks the three string representations of a selection by index."""
     source = computation_building_blocks.Reference(
         'a', (('b', tf.int32), ('c', tf.bool)))
     selection = computation_building_blocks.Selection(source, index=0)
     self.assertEqual(selection.compact_representation(), 'a[0]')
     self.assertEqual(selection.formatted_representation(), 'a[0]')
     # pyformat: disable
     expected_structure = (
         'Sel(0)\n'
         '|\n'
         'Ref(a)'
     )
     # pyformat: enable
     self.assertEqual(selection.structural_representation(),
                      expected_structure)
def create_federated_zip(value):
    r"""Creates a called federated zip.

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Comp, Comp]

  Given a `value` whose type signature is a tuple of federated values, this
  function returns a federated tuple. The placement of the first element
  decides whether a federated map (CLIENTS) or a federated apply (SERVER) is
  used. The element names of `value` are preserved in the result.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least one element.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If any of the types do not match, or the placement of the first
      element is neither CLIENTS nor SERVER.
    ValueError: If `value` does not contain any elements.
  """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    elements = anonymous_tuple.to_elements(value.type_signature)
    element_names = [element_name for element_name, _ in elements]
    if not elements:
        raise ValueError(
            'federated_zip is only supported on non-empty tuples.')

    head_name, head_type = elements[0]
    if head_type.placement == placement_literals.CLIENTS:
        map_fn = create_federated_map
    elif head_type.placement == placement_literals.SERVER:
        map_fn = create_federated_apply
    else:
        raise TypeError('Unsupported placement {}.'.format(
            head_type.placement))

    if len(elements) > 1:
        # Zip the values pairwise into a chain, flatten the chain, and then
        # restore the original element names.
        chained_values = _create_chain_zipped_values(value)
        flatten_fn = _create_fn_to_append_chain_zipped_values(value)
        return construct_named_federated_tuple(
            map_fn(flatten_fn, chained_values), element_names)

    # Single element: wrap the lone member in a one-element tuple.
    arg_ref = computation_building_blocks.Reference('arg', head_type.member)
    wrapped = computation_building_blocks.Tuple(((head_name, arg_ref), ))
    wrap_fn = computation_building_blocks.Lambda(arg_ref.name,
                                                 arg_ref.type_signature,
                                                 wrapped)
    only_element = computation_building_blocks.Selection(value, index=0)
    return map_fn(wrap_fn, only_element)
 def test_intrinsic_construction_clients(self):
   """Checks that a CLIENTS-placed argument yields `federated_map`."""
   clients_type = computation_types.FederatedType(
       [('a', tf.int32), ('b', tf.bool)], placement_literals.CLIENTS, True)
   federated_value = computation_building_blocks.Reference(
       'test', clients_type)
   param = computation_building_blocks.Reference(
       'x', [('a', tf.int32), ('b', tf.bool)])
   select_a = computation_building_blocks.Lambda(
       'x', param.type_signature,
       computation_building_blocks.Selection(param, name='a'))
   built = computation_constructing_utils.construct_map_or_apply(
       select_a, federated_value)
   self.assertEqual(str(built), 'federated_map')
예제 #24
0
 def test_inline_conflicting_locals(self):
     """Checks inlining when a block local reuses the lambda parameter name."""
     outer_arg = computation_building_blocks.Reference(
         'arg', [tf.int32, tf.int32])
     first_element = computation_building_blocks.Selection(outer_arg,
                                                           index=0)
     # The block rebinds the name 'arg' to the first element of the outer
     # 'arg'.
     shadowing_block = computation_building_blocks.Block(
         [('arg', first_element)],
         computation_building_blocks.Reference('arg', tf.int32))
     fn = computation_building_blocks.Lambda(
         'arg', outer_arg.type_signature, shadowing_block)
     self.assertEqual(str(fn), '(arg -> (let arg=arg[0] in arg))')
     result = transformations.inline_blocks_with_n_referenced_locals(fn)
     self.assertEqual(str(result), '(arg -> (let  in arg[0]))')
예제 #25
0
 def _extract_from_selection(comp):
   """Moves the selection in `comp` inside a block and extracts from it.

   If the source of the selection is a called intrinsic, the intrinsic is
   bound to a freshly generated name and the selection is applied to a
   reference to that name. Otherwise the source is treated as a block: its
   locals are reused and the selection is applied to its result. The
   resulting block is handed to `_extract_from_block`.
   """
   source = comp.source
   if _is_called_intrinsic(source):
     fresh_name = six.next(name_generator)
     bindings = ((fresh_name, source),)
     selected_from = computation_building_blocks.Reference(
         fresh_name, source.type_signature)
   else:
     bindings = source.locals
     selected_from = source.result
   hoisted = computation_building_blocks.Block(
       bindings,
       computation_building_blocks.Selection(
           selected_from, name=comp.name, index=comp.index))
   return _extract_from_block(hoisted)
 def test_basic_functionality_of_selection_class(self):
     """Exercises `Selection` by name and by index over a two-element tuple."""
     foo = computation_building_blocks.Reference('foo', [('bar', tf.int32),
                                                         ('baz', tf.bool)])

     # Selection by name.
     bar = computation_building_blocks.Selection(foo, name='bar')
     self.assertEqual(bar.name, 'bar')
     self.assertEqual(bar.index, None)
     self.assertEqual(str(bar.type_signature), 'int32')
     expected_bar_repr = (
         'Selection(Reference(\'foo\', NamedTupleType(['
         '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
         ', name=\'bar\')')
     self.assertEqual(repr(bar), expected_bar_repr)
     self.assertEqual(
         computation_building_blocks.compact_representation(bar), 'foo.bar')

     baz = computation_building_blocks.Selection(foo, name='baz')
     self.assertEqual(str(baz.type_signature), 'bool')
     self.assertEqual(
         computation_building_blocks.compact_representation(baz), 'foo.baz')

     # A name that is not in the tuple is rejected.
     with self.assertRaises(ValueError):
         computation_building_blocks.Selection(foo, name='bak')

     # Selection by index.
     first = computation_building_blocks.Selection(foo, index=0)
     self.assertEqual(first.name, None)
     self.assertEqual(first.index, 0)
     self.assertEqual(str(first.type_signature), 'int32')
     expected_first_repr = (
         'Selection(Reference(\'foo\', NamedTupleType(['
         '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
         ', index=0)')
     self.assertEqual(repr(first), expected_first_repr)
     self.assertEqual(
         computation_building_blocks.compact_representation(first), 'foo[0]')

     second = computation_building_blocks.Selection(foo, index=1)
     self.assertEqual(str(second.type_signature), 'bool')
     self.assertEqual(
         computation_building_blocks.compact_representation(second),
         'foo[1]')

     # Indices outside of the tuple are rejected.
     for bad_index in (2, -1):
         with self.assertRaises(ValueError):
             computation_building_blocks.Selection(foo, index=bad_index)

     # Serialized form of the by-name selection.
     bar_proto = bar.proto
     self.assertEqual(type_serialization.deserialize_type(bar_proto.type),
                      bar.type_signature)
     self.assertEqual(bar_proto.WhichOneof('computation'), 'selection')
     self.assertEqual(str(bar_proto.selection.source), str(foo.proto))
     self.assertEqual(bar_proto.selection.name, 'bar')

     for selection in (bar, baz, first, second):
         self._serialize_deserialize_roundtrip_test(selection)
예제 #27
0
        def _create_block_to_chained_calls(comps):
            r"""Constructs a transformed block computation from `comps`.

                     Block
                    /     \
          [fn=Tuple]       Lambda(arg)
              |                       \
      [Comp(y), Comp(x)]               Call
                                      /    \
                                Sel(1)      Call
                               /           /    \
                        Ref(fn)      Sel(0)      Ref(arg)
                                    /
                             Ref(fn)

      (let fn=<y, x> in (arg -> fn[1](fn[0](arg)))

      Args:
        comps: A Python list of computations.

      Returns:
        A `computation_building_blocks.Block`.
      """
            functions = computation_building_blocks.Tuple(comps)
            functions_name = six.next(name_generator)
            functions_ref = computation_building_blocks.Reference(
                functions_name, functions.type_signature)
            arg_name = six.next(name_generator)
            arg_type = comps[0].type_signature.parameter
            arg_ref = computation_building_blocks.Reference(arg_name, arg_type)
            arg = arg_ref
            for index, _ in enumerate(comps):
                fn_sel = computation_building_blocks.Selection(functions_ref,
                                                               index=index)
                call = computation_building_blocks.Call(fn_sel, arg)
                arg = call
            fn = computation_building_blocks.Lambda(arg_ref.name,
                                                    arg_ref.type_signature,
                                                    call)
            return computation_building_blocks.Block(
                ((functions_ref.name, functions), ), fn)
예제 #28
0
def _create_fn_to_append_chain_zipped_values(value):
    r"""Creates a function to append a chain of zipped values.

  Lambda(arg3)
            \
             append([Call,    Sel(1)])
                    /    \            \
        Lambda(arg2)      Sel(0)       Ref(arg3)
                  \             \
                   \             Ref(arg3)
                    \
                     append([Call,    Sel(1)])
                            /    \            \
                Lambda(arg1)      Sel(0)       Ref(arg2)
                            \           \
                             \           Ref(arg2)
                              \
                               Ref(arg1)

  NOTE: This function is intended to be used in conjunction with
  `_create_chain_zipped_values` and will add back the names that were dropped
  when zipping the values.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements.

  Returns:
    A `computation_building_blocks.Lambda`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` contains fewer than two elements.
  """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length < 2:
        # Report the element count, which is what the message describes.
        raise ValueError(
            'Expected a value with at least two elements, received {} elements.'
            .format(length))
    first_name, first_type_signature = named_type_signatures[0]
    second_name, second_type_signature = named_type_signatures[1]
    # Innermost function: the identity over the (named) pair of the first two
    # member types.
    ref_type = computation_types.NamedTupleType((
        (first_name, first_type_signature.member),
        (second_name, second_type_signature.member),
    ))
    ref = computation_building_blocks.Reference('arg', ref_type)
    fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, ref)
    # Each further element wraps the previous function: the new parameter is a
    # pair of (previous parameter, next member); the previous function is
    # applied to element 0 and element 1 is appended under its name.
    for name, type_signature in named_type_signatures[2:]:
        ref_type = computation_types.NamedTupleType((
            fn.type_signature.parameter,
            (name, type_signature.member),
        ))
        ref = computation_building_blocks.Reference('arg', ref_type)
        sel_0 = computation_building_blocks.Selection(ref, index=0)
        call = computation_building_blocks.Call(fn, sel_0)
        sel_1 = computation_building_blocks.Selection(ref, index=1)
        result = create_computation_appending(call, (name, sel_1))
        fn = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                                result)
    return fn
def _create_fn_to_append_chain_zipped_values(value):
    r"""Creates a function to append a chain of zipped values.

  Lambda(arg3)
            \
             append([Call,    Sel(1)])
                    /    \            \
        Lambda(arg2)      Sel(0)       Ref(arg3)
                  \             \
                   \             Ref(arg3)
                    \
                     append([Call,    Sel(1)])
                            /    \            \
                Lambda(arg1)      Sel(0)       Ref(arg2)
                            \           \
                             \           Ref(arg2)
                              \
                               Ref(arg1)

  Note that this function will not respect any names it is passed; names
  for tuples will be cached at a higher level than this function and added back
  in a single call to federated map or federated apply.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements.

  Returns:
    A `computation_building_blocks.Lambda`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` contains fewer than two elements.
  """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length < 2:
        # Report the element count, which is what the message describes.
        raise ValueError(
            'Expected a value with at least two elements, received {} elements.'
            .format(length))
    _, first_type_signature = named_type_signatures[0]
    _, second_type_signature = named_type_signatures[1]
    # Innermost function: the identity over the unnamed pair of the first two
    # member types.
    ref_type = computation_types.NamedTupleType((
        first_type_signature.member,
        second_type_signature.member,
    ))
    ref = computation_building_blocks.Reference('arg', ref_type)
    fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, ref)
    # Each further element wraps the previous function: the new parameter is a
    # pair of (previous parameter, next member); the previous function is
    # applied to element 0 and element 1 is appended to its result.
    for _, type_signature in named_type_signatures[2:]:
        ref_type = computation_types.NamedTupleType((
            fn.type_signature.parameter,
            type_signature.member,
        ))
        ref = computation_building_blocks.Reference('arg', ref_type)
        sel_0 = computation_building_blocks.Selection(ref, index=0)
        call = computation_building_blocks.Call(fn, sel_0)
        sel_1 = computation_building_blocks.Selection(ref, index=1)
        result = create_computation_appending(call, sel_1)
        fn = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                                result)
    return fn
예제 #30
0
def construct_named_tuple_setattr_lambda(named_tuple_signature, name,
                                         value_comp):
    """Constructs a building block for replacing one attribute in a named tuple.

  The result is a `computation_building_blocks.Block` that binds `value_comp`
  to a local and whose result is a `computation_building_blocks.Lambda`. That
  lambda takes an argument of type `named_tuple_signature` and returns a
  `computation_building_blocks.Tuple` holding the same elements as its
  argument, except that the element called `name` is replaced by the bound
  `value_comp`. It is the analogue of Python's `setattr` for the concrete type
  `named_tuple_signature`.

  Args:
    named_tuple_signature: Instance of `computation_types.NamedTupleType`, the
      type of the argument to the constructed
      `computation_building_blocks.Lambda`.
    name: String name of the attribute in `named_tuple_signature` to replace
      with `value_comp`. Must be present as a name in `named_tuple_signature`;
      otherwise an `AttributeError` is raised.
    value_comp: Instance of
      `computation_building_blocks.ComputationBuildingBlock`, the value to place
      as attribute `name` in the argument of the returned function.

  Returns:
    An instance of `computation_building_blocks.Block` of functional type
    representing setting attribute `name` to value `value_comp` in its argument
    of type `named_tuple_signature`.

  Raises:
    TypeError: If the types of the arguments don't match the assumptions above,
      or the type of `value_comp` is not assignable to the element `name`.
    AttributeError: If `name` is not present as a named element in
      `named_tuple_signature`.
  """
    py_typecheck.check_type(named_tuple_signature,
                            computation_types.NamedTupleType)
    py_typecheck.check_type(name, six.string_types)
    py_typecheck.check_type(
        value_comp, computation_building_blocks.ComputationBuildingBlock)
    placeholder = computation_building_blocks.Reference(
        'value_comp_placeholder', value_comp.type_signature)
    tuple_arg = computation_building_blocks.Reference('lambda_arg',
                                                      named_tuple_signature)
    if name not in dir(named_tuple_signature):
        raise AttributeError(
            'There is no such attribute as \'{}\' in this federated tuple. '
            'TFF does not allow for assigning to a nonexistent attribute. '
            'If you want to assign to \'{}\', you must create a new named tuple '
            'containing this attribute.'.format(name, name))

    rebuilt = []
    for position, (element_name, element_type) in enumerate(
            anonymous_tuple.to_elements(named_tuple_signature)):
        if element_name != name:
            # Untouched elements are forwarded from the lambda argument.
            rebuilt.append(
                (element_name,
                 computation_building_blocks.Selection(tuple_arg,
                                                       index=position)))
            continue
        if not type_utils.is_assignable_from(element_type,
                                             value_comp.type_signature):
            raise TypeError(
                '`setattr` has attempted to set element {} of type {} with incompatible type {}'
                .format(element_name, element_type,
                        value_comp.type_signature))
        # The replaced element refers to the local bound in the outer block.
        rebuilt.append((element_name, placeholder))

    setter = computation_building_blocks.Lambda(
        tuple_arg.name, named_tuple_signature,
        computation_building_blocks.Tuple(rebuilt))
    bindings = ((placeholder.name, value_comp), )
    return computation_building_blocks.Block(bindings, setter)