コード例 #1
0
 def test_basic_functionality_of_lambda_class(self):
     """Checks type, accessors, reprs and proto serialization of `Lambda`.

     Builds `(arg -> arg.f(arg.f(arg.x)))`, where `arg` is a named tuple of a
     function `f: int32 -> int32` and an `int32` value `x`.
     """
     arg_name = 'arg'
     # A list of (name, type) pairs; the expected reprs below show it is
     # treated as a named tuple type.
     arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                 ('x', tf.int32)]
     arg = computation_building_blocks.Reference(arg_name, arg_type)
     arg_f = computation_building_blocks.Selection(arg, name='f')
     arg_x = computation_building_blocks.Selection(arg, name='x')
     # The body applies `f` twice: arg.f(arg.f(arg.x)).
     x = computation_building_blocks.Lambda(
         arg_name, arg_type,
         computation_building_blocks.Call(
             arg_f, computation_building_blocks.Call(arg_f, arg_x)))
     # Type signature, parameter accessors and body rendering.
     self.assertEqual(str(x.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(x.parameter_name, arg_name)
     self.assertEqual(str(x.parameter_type), '<f=(int32 -> int32),x=int32>')
     self.assertEqual(x.result.tff_repr, 'arg.f(arg.f(arg.x))')
     # Python repr: every occurrence of the parameter type renders the same,
     # so it is substituted into the expected string via `{0}`.
     arg_type_repr = (
         'NamedTupleType(['
         '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
         '(\'x\', TensorType(tf.int32))])')
     self.assertEqual(
         repr(x), 'Lambda(\'arg\', {0}, '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
             arg_type_repr))
     self.assertEqual(x.tff_repr, '(arg -> arg.f(arg.f(arg.x)))')
     # Proto form: the serialized type deserializes back to the signature and
     # the `lambda` oneof carries the parameter name and serialized body.
     x_proto = x.proto
     self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                      x.type_signature)
     self.assertEqual(x_proto.WhichOneof('computation'), 'lambda')
     # `lambda` is a Python keyword, so the proto field is read via getattr.
     self.assertEqual(getattr(x_proto, 'lambda').parameter_name, arg_name)
     self.assertEqual(str(getattr(x_proto, 'lambda').result),
                      str(x.result.proto))
     # Helper defined elsewhere in the test class; presumably checks that a
     # serialize/deserialize cycle preserves `x`.
     self._serialize_deserialize_roundtrip_test(x)
コード例 #2
0
 def test_nested_lambda_block_overwrite_scope_snapshot(self):
     """Each scope that rebinds `x` gets its own entry in the scope snapshot.

     The computation is
     `(let x=test_data in (block_in -> (let x=block_in in (x -> x)(x)))(x))`,
     so `x` is bound three times (two blocks and one lambda) and `block_in`
     once.
     """
     bb = computation_building_blocks
     # Innermost: the identity lambda (x -> x), called on the block's `x`.
     identity_fn = bb.Lambda('x', tf.int32, bb.Reference('x', tf.int32))
     inner_block = bb.Block(
         [('x', bb.Reference('block_in', tf.int32))],
         bb.Call(identity_fn, bb.Reference('x', tf.int32)))
     block_fn = bb.Lambda('block_in', tf.int32, inner_block)
     # Outermost: bind `x` to data and call `block_fn` on it.
     outer_block = bb.Block(
         [('x', bb.Data('test_data', tf.int32))],
         bb.Call(block_fn, bb.Reference('x', tf.int32)))

     snapshot = transformations.scope_count_snapshot(outer_block)

     self.assertEqual(
         str(outer_block),
         '(let x=test_data in (block_in -> (let x=block_in in (x -> x)(x)))(x))'
     )
     self.assertLen(snapshot, 4)
     expected = (
         (identity_fn, {'x': 1}),
         (inner_block, {'x': 1}),
         (block_fn, {'block_in': 1}),
         (outer_block, {'x': 1}),
     )
     for scope, counts in expected:
         self.assertEqual(snapshot[str(scope)], counts)
コード例 #3
0
    def test_with_block(self):
        """Runs a `Lambda` whose body is a `Block` on a (function, int) tuple.

        The lambda takes `a = <f, x>`, binds `f` and `x` from `a` in a block,
        and returns `f(f(x))`. With `f = add_one` and `x = 10` the result is 12.
        """
        ex = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())

        f_type = computation_types.FunctionType(tf.int32, tf.int32)
        a = computation_building_blocks.Reference(
            'a',
            computation_types.NamedTupleType([('f', f_type), ('x', tf.int32)]))
        ret = computation_building_blocks.Block(
            [('f', computation_building_blocks.Selection(a, name='f')),
             ('x', computation_building_blocks.Selection(a, name='x'))],
            computation_building_blocks.Call(
                computation_building_blocks.Reference('f', f_type),
                computation_building_blocks.Call(
                    computation_building_blocks.Reference('f', f_type),
                    computation_building_blocks.Reference('x', tf.int32))))
        comp = computation_building_blocks.Lambda(a.name, a.type_signature,
                                                  ret)

        @computations.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        # Use a dedicated loop: `asyncio.get_event_loop()` outside a running
        # loop is deprecated on recent Python versions. Close the loop when
        # done so the test does not leak it.
        loop = asyncio.new_event_loop()
        try:
            v1 = loop.run_until_complete(
                ex.create_value(comp.proto, comp.type_signature))
            v2 = loop.run_until_complete(ex.create_value(add_one))
            v3 = loop.run_until_complete(ex.create_value(10, tf.int32))
            v4 = loop.run_until_complete(
                ex.create_tuple(
                    anonymous_tuple.AnonymousTuple([('f', v2), ('x', v3)])))
            v5 = loop.run_until_complete(ex.create_call(v1, v4))
            result = loop.run_until_complete(v5.compute())
            self.assertEqual(result.numpy(), 12)
        finally:
            loop.close()
コード例 #4
0
 def test_scope_snapshot_called_lambdas(self):
     comp = computation_building_blocks.Tuple(
         [computation_building_blocks.Data('test', tf.int32)])
     input1 = computation_building_blocks.Reference('input1',
                                                    comp.type_signature)
     first_level_call = computation_building_blocks.Call(
         computation_building_blocks.Lambda('input1', input1.type_signature,
                                            input1), comp)
     input2 = computation_building_blocks.Reference(
         'input2', first_level_call.type_signature)
     second_level_call = computation_building_blocks.Call(
         computation_building_blocks.Lambda('input2', input2.type_signature,
                                            input2), first_level_call)
     self.assertEqual(str(second_level_call),
                      '(input2 -> input2)((input1 -> input1)(<test>))')
     global_snapshot = transformations.scope_count_snapshot(
         second_level_call)
     self.assertEqual(
         global_snapshot, {
             '(input2 -> input2)': {
                 'input2': 1
             },
             '(input1 -> input1)': {
                 'input1': 1
             }
         })
コード例 #5
0
    def test_returns_string_for_comp_with_left_overhang(self):
        """Checks compact/formatted/structural strings for `a(comp#bbbbb())`.

        In the structural rendering the bottom-left leaf, `Compiled(bbbbb)`,
        is wider than the space left of the root, so the root `Call` is drawn
        indented and that leaf starts at column 0 (the "left overhang").
        """
        fn_type = computation_types.FunctionType(tf.int32, tf.int32)
        fn = computation_building_blocks.Reference('a', fn_type)
        # A no-argument compiled TF computation returning the constant 1,
        # named 'bbbbb'; calling it produces the argument for `a`.
        proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        compiled = computation_building_blocks.CompiledComputation(
            proto, 'bbbbb')
        arg = computation_building_blocks.Call(compiled)

        comp = computation_building_blocks.Call(fn, arg)
        # Compact and formatted renderings coincide for this short computation.
        compact_string = computation_building_blocks.compact_representation(
            comp)
        self.assertEqual(compact_string, 'a(comp#bbbbb())')
        formatted_string = computation_building_blocks.formatted_representation(
            comp)
        self.assertEqual(formatted_string, 'a(comp#bbbbb())')
        structural_string = computation_building_blocks.structural_representation(
            comp)
        # pyformat: disable
        self.assertEqual(
            structural_string, '           Call\n'
            '          /    \\\n'
            '    Ref(a)      Call\n'
            '               /\n'
            'Compiled(bbbbb)')
        # pyformat: enable
コード例 #6
0
  def test_execute_with_nested_lambda(self):
    """Executes curried `(x -> (y -> add(<x, y>)))` applied to 10, then to 5."""
    bb = computation_building_blocks

    def as_building_block(tff_fn):
      # Converts a TFF computation to a building block via its proto.
      return bb.ComputationBuildingBlock.from_proto(
          computation_impl.ComputationImpl.get_proto(tff_fn))

    int32_add = as_building_block(
        computations.tf_computation(tf.add, [tf.int32, tf.int32]))
    x_ref = bb.Reference('x', tf.int32)
    y_ref = bb.Reference('y', tf.int32)
    add_x_y = bb.Call(int32_add, bb.Tuple([(None, x_ref), (None, y_ref)]))
    curried_int32_add = bb.Lambda('x', tf.int32,
                                  bb.Lambda('y', tf.int32, add_x_y))

    make_10 = as_building_block(
        computations.tf_computation(lambda: tf.constant(10)))
    # Partially apply the curried add to the result of `make_10()`.
    add_10 = bb.Call(curried_int32_add, bb.Call(make_10))

    add_10_computation = computation_impl.ComputationImpl(
        add_10.proto, context_stack_impl.context_stack)
    self.assertEqual(add_10_computation(5), 15)
コード例 #7
0
 def test_basic_functionality_of_call_class(self):
     """Checks type, accessors, reprs, argument validation and proto of `Call`."""
     fn_ref = computation_building_blocks.Reference(
         'foo', computation_types.FunctionType(tf.int32, tf.bool))
     arg_ref = computation_building_blocks.Reference('bar', tf.int32)
     call = computation_building_blocks.Call(fn_ref, arg_ref)

     self.assertEqual(str(call.type_signature), 'bool')
     self.assertIs(call.function, fn_ref)
     self.assertIs(call.argument, arg_ref)
     expected_repr = (
         "Call(Reference('foo', "
         "FunctionType(TensorType(tf.int32), TensorType(tf.bool))), "
         "Reference('bar', TensorType(tf.int32)))")
     self.assertEqual(repr(call), expected_repr)
     self.assertEqual(call.tff_repr, 'foo(bar)')

     # A missing argument, or one of the wrong type, is rejected.
     with self.assertRaises(TypeError):
         computation_building_blocks.Call(fn_ref)
     float_ref = computation_building_blocks.Reference('bak', tf.float32)
     with self.assertRaises(TypeError):
         computation_building_blocks.Call(fn_ref, float_ref)

     # Proto form mirrors the function and argument protos.
     call_proto = call.proto
     self.assertEqual(type_serialization.deserialize_type(call_proto.type),
                      call.type_signature)
     self.assertEqual(call_proto.WhichOneof('computation'), 'call')
     self.assertEqual(str(call_proto.call.function), str(fn_ref.proto))
     self.assertEqual(str(call_proto.call.argument), str(arg_ref.proto))
     self._serialize_deserialize_roundtrip_test(call)
コード例 #8
0
  def test_execute_with_block(self):
    """Executes a block that binds `x` to 10 and then rebinds it to `x + 1` thrice."""
    bb = computation_building_blocks

    def as_building_block(tff_fn):
      # Converts a TFF computation to a building block via its proto.
      return bb.ComputationBuildingBlock.from_proto(
          computation_impl.ComputationImpl.get_proto(tff_fn))

    add_one = as_building_block(
        computations.tf_computation(lambda x: x + 1, tf.int32))
    make_10 = as_building_block(
        computations.tf_computation(lambda: tf.constant(10)))

    # Every later binding reads the previous `x` and shadows it.
    bindings = [('x', bb.Call(make_10))]
    bindings.extend(
        ('x', bb.Call(add_one, bb.Reference('x', tf.int32))) for _ in range(3))
    make_13 = bb.Block(bindings, bb.Reference('x', tf.int32))

    make_13_computation = computation_impl.ComputationImpl(
        make_13.proto, context_stack_impl.context_stack)
    self.assertEqual(make_13_computation(), 13)
コード例 #9
0
 def transform(self, comp):
     """Pushes a selection into the graph of the call it selects from.

     `comp` presumably is a selection (it exposes `index`, `name` and
     `source`) whose `source` is a call. The selection is applied to the
     called function via `select_graph_output`, by `index` when one is set and
     by `name` otherwise, and the narrowed function is called on the original
     argument.
     """
     if comp.index is not None:
         selector = {'index': comp.index}
     else:
         selector = {'name': comp.name}
     narrowed_fn = select_graph_output(comp.source.function, **selector)
     return computation_building_blocks.Call(narrowed_fn, comp.source.argument)
コード例 #10
0
def construct_federated_getitem_call(arg, idx):
  """Constructs a call that applies `__getitem__` to a federated value.

  Builds the per-member selection function with
  `construct_federated_getitem_comp`, obtains the placement-appropriate
  map/apply from `construct_map_or_apply`, and calls that on the tuple
  `<getitem_comp, arg>`.

  Args:
    arg: Instance of `computation_building_blocks.ComputationBuildingBlock` of
      `computation_types.FederatedType` with member of type
      `computation_types.NamedTupleType` from which we wish to pick out item
      `idx`.
    idx: Index, instance of `int` or `slice` used to address the
      `computation_types.NamedTupleType` underlying `arg`.

  Returns:
    A `computation_building_blocks.Call` of federated type with the same
    placement as `arg`, the result of applying or mapping the `__getitem__`
    function defined by `idx`.

  Raises:
    TypeError: If `arg` is not a building block, `idx` is not an `int` or
      `slice`, or `arg` is not federated with a named-tuple member.
  """
  py_typecheck.check_type(
      arg, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(idx, (int, slice))
  federated_type = arg.type_signature
  py_typecheck.check_type(federated_type, computation_types.FederatedType)
  py_typecheck.check_type(federated_type.member,
                          computation_types.NamedTupleType)
  selector_fn = construct_federated_getitem_comp(arg, idx)
  # NOTE(review): the value returned by `construct_map_or_apply` is used as the
  # function of a new `Call`, so it is expected to be an intrinsic. A version
  # of that helper that returns an already-built `Call` would be wrapped a
  # second time here; confirm against its definition.
  map_or_apply = construct_map_or_apply(selector_fn, arg)
  return computation_building_blocks.Call(
      map_or_apply, computation_building_blocks.Tuple([selector_fn, arg]))
コード例 #11
0
def construct_federated_getattr_call(arg, name):
    """Constructs a call that applies `__getattr__` to a federated value.

    Builds the per-member attribute-selection function with
    `construct_federated_getattr_comp`, obtains the placement-appropriate
    map/apply from `construct_map_or_apply`, and calls that on the tuple
    `<getattr_comp, arg>`.

    Args:
      arg: Instance of `computation_building_blocks.ComputationBuildingBlock`
        of `computation_types.FederatedType` with member of type
        `computation_types.NamedTupleType` from which we wish to pick out item
        `name`.
      name: String name to address the `computation_types.NamedTupleType`
        underlying `arg`.

    Returns:
      A `computation_building_blocks.Call` with type signature
      `computation_types.FederatedType` of the same placement as `arg`, the
      result of applying or mapping the `__getattr__` function defined by
      `name`.

    Raises:
      TypeError: If `arg` is not a building block, `name` is not a string, or
        `arg` is not federated with a named-tuple member.
    """
    py_typecheck.check_type(
        arg, computation_building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(name, six.string_types)
    federated_type = arg.type_signature
    py_typecheck.check_type(federated_type, computation_types.FederatedType)
    py_typecheck.check_type(federated_type.member,
                            computation_types.NamedTupleType)
    selector_fn = construct_federated_getattr_comp(arg, name)
    # NOTE(review): the value returned by `construct_map_or_apply` is used as
    # the function of a new `Call`, so it is expected to be an intrinsic. A
    # version of that helper that returns an already-built `Call` would be
    # wrapped a second time here; confirm against its definition.
    map_or_apply = construct_map_or_apply(selector_fn, arg)
    return computation_building_blocks.Call(
        map_or_apply, computation_building_blocks.Tuple([selector_fn, arg]))
コード例 #12
0
def _create_lambda_to_add_one(dtype):
    r"""Builds `(arg -> GENERIC_PLUS(<arg, 1>))` for the given dtype.

    Lambda
          \
           Call
          /    \
    Intrinsic   Tuple
               /     \
      Reference       Computation

    Args:
      dtype: A `tf.dtypes.DType`, or a `computation_types.TensorType` whose
        `dtype` is used instead.

    Returns:
      A `computation_building_blocks.Lambda` taking `arg` of that dtype and
      returning `arg` plus a constant 1 cast to the same dtype.

    Raises:
      TypeError: If the (unwrapped) `dtype` is not a `tf.dtypes.DType`.
    """
    scalar_dtype = (
        dtype.dtype
        if isinstance(dtype, computation_types.TensorType) else dtype)
    py_typecheck.check_type(scalar_dtype, tf.dtypes.DType)
    plus = computation_building_blocks.Intrinsic(
        intrinsic_defs.GENERIC_PLUS.uri,
        computation_types.FunctionType([scalar_dtype, scalar_dtype],
                                       scalar_dtype))
    param = computation_building_blocks.Reference('arg', scalar_dtype)
    one = _create_call_to_py_fn(lambda: tf.cast(tf.constant(1), scalar_dtype))
    body = computation_building_blocks.Call(
        plus, computation_building_blocks.Tuple([param, one]))
    return computation_building_blocks.Lambda(param.name, param.type_signature,
                                              body)
コード例 #13
0
def get_curried(func):
  """Returns a curried version of function `func` that takes a parameter tuple.

  For `func` of type `<T1,T2,...,Tn> -> U`, the result has type
  `T1 -> (T2 -> (T3 -> ... (Tn -> U) ... ))`.

  NOTE: No attempt is made to avoid naming conflicts in cases where `func`
  contains references. The parameters of the curried function are named
  `argN`, with `N` starting at 0.

  Args:
    func: A `value_base.Value` of functional TFF type whose parameter is a
      named tuple.

  Returns:
    A `value_impl.ValueImpl`, in the same context stack as `func`, that
    represents the curried form of `func`.

  Raises:
    TypeError: If `func` is not a value, is not functional, or its parameter
      is not a named tuple.
  """
  py_typecheck.check_type(func, value_base.Value)
  signature = func.type_signature
  py_typecheck.check_type(signature, computation_types.FunctionType)
  py_typecheck.check_type(signature.parameter,
                          computation_types.NamedTupleType)
  refs = [
      computation_building_blocks.Reference('arg{}'.format(position), elem_type)
      for position, (_, elem_type) in enumerate(
          anonymous_tuple.to_elements(signature.parameter))
  ]
  curried = computation_building_blocks.Call(
      value_impl.ValueImpl.get_comp(func),
      computation_building_blocks.Tuple(refs))
  # Wrap from the last parameter outward so `arg0` ends up outermost.
  for ref in reversed(refs):
    curried = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                                 curried)
  return value_impl.ValueImpl(curried,
                              value_impl.ValueImpl.get_context_stack(func))
コード例 #14
0
    def sequence_reduce(self, value, zero, op):
        """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

        Accepts either a sequence, or a federated sequence placed at SERVER or
        CLIENTS. In the federated case the reduction is applied per member
        through `federated_apply` / `federated_map`.

        Args:
          value: As in `api/intrinsics.py`.
          zero: As in `api/intrinsics.py`.
          op: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`. The non-federated branch wraps its
          building block in a `value_impl.ValueImpl` bound to
          `self._context_stack`, so callers never receive a raw building block.

        Raises:
          TypeError: As in `api/intrinsics.py`, including when `op` does not
            match the expected reduction operator type, or the federated
            placement is neither SERVER nor CLIENTS.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        zero = value_impl.to_value(zero, None, self._context_stack)
        op = value_impl.to_value(op, None, self._context_stack)
        # Element type of the sequence, unwrapping a federated sequence.
        if isinstance(value.type_signature, computation_types.SequenceType):
            element_type = value.type_signature.element
        else:
            py_typecheck.check_type(value.type_signature,
                                    computation_types.FederatedType)
            py_typecheck.check_type(value.type_signature.member,
                                    computation_types.SequenceType)
            element_type = value.type_signature.member.element
        op_type_expected = type_constructors.reduction_op(
            zero.type_signature, element_type)
        if not type_utils.is_assignable_from(op_type_expected,
                                             op.type_signature):
            raise TypeError('Expected an operator of type {}, got {}.'.format(
                op_type_expected, op.type_signature))

        # From here on, work with the underlying building blocks.
        value = value_impl.ValueImpl.get_comp(value)
        zero = value_impl.ValueImpl.get_comp(zero)
        op = value_impl.ValueImpl.get_comp(op)
        if isinstance(value.type_signature, computation_types.SequenceType):
            reduced = computation_constructing_utils.create_sequence_reduce(
                value, zero, op)
            # `create_sequence_reduce` yields a building block. Wrap it so this
            # branch returns a value, as an intrinsic implementation must; the
            # federated branches below delegate to `federated_apply` /
            # `federated_map` (which presumably return values as well).
            return value_impl.ValueImpl(reduced, self._context_stack)
        else:
            # Federated case: reduce each member sequence via `(arg -> ...)`.
            value_type = computation_types.SequenceType(element_type)
            intrinsic_type = computation_types.FunctionType((
                value_type,
                zero.type_signature,
                op.type_signature,
            ), op.type_signature.result)
            intrinsic = computation_building_blocks.Intrinsic(
                intrinsic_defs.SEQUENCE_REDUCE.uri, intrinsic_type)
            ref = computation_building_blocks.Reference('arg', value_type)
            tup = computation_building_blocks.Tuple((ref, zero, op))
            call = computation_building_blocks.Call(intrinsic, tup)
            fn = computation_building_blocks.Lambda(ref.name,
                                                    ref.type_signature, call)
            fn_impl = value_impl.ValueImpl(fn, self._context_stack)
            if value.type_signature.placement is placements.SERVER:
                return self.federated_apply(fn_impl, value)
            elif value.type_signature.placement is placements.CLIENTS:
                return self.federated_map(fn_impl, value)
            else:
                raise TypeError('Unsupported placement {}.'.format(
                    value.type_signature.placement))
コード例 #15
0
    def test_replace_called_lambda_replaces_called_lambda(self):
        """Replacing a called lambda with a block preserves the computed value.

        Rewriting `(arg -> add_one(arg))` must remove one `Call` and one
        `Lambda`, introduce exactly one `Block`, and still map 1 to 2.
        """
        ref = computation_building_blocks.Reference('arg', tf.int32)
        add_one = _create_lambda_to_add_one(ref.type_signature)
        comp = computation_building_blocks.Lambda(
            ref.name, ref.type_signature,
            computation_building_blocks.Call(add_one, ref))

        self.assertEqual(
            _get_number_of_computations(comp,
                                        computation_building_blocks.Block), 0)
        self.assertEqual(_to_comp(comp)(1), 2)

        transformed_comp = transformations.replace_called_lambda_with_block(
            comp)

        for kind in (computation_building_blocks.Call,
                     computation_building_blocks.Lambda):
            self.assertEqual(
                _get_number_of_computations(transformed_comp, kind),
                _get_number_of_computations(comp, kind) - 1)
        self.assertEqual(
            _get_number_of_computations(transformed_comp,
                                        computation_building_blocks.Block), 1)
        self.assertEqual(_to_comp(transformed_comp)(1), 2)
コード例 #16
0
def _create_call_to_federated_map(fn, arg):
    r"""Builds a `FEDERATED_MAP` call of `fn` over `arg`, placed at CLIENTS.

              Call
             /    \
    Intrinsic      Tuple
                  /     \
       Computation       Computation

    Args:
      fn: A functional `computation_building_blocks.ComputationBuildingBlock`
        to use as the map function.
      arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the map argument.

    Returns:
      A `computation_building_blocks.Call` of the federated map intrinsic on
      `<fn, arg>`, typed as `fn`'s result placed at CLIENTS.

    Raises:
      TypeError: If `fn` or `arg` is not a
        `computation_building_blocks.ComputationBuildingBlock`.
    """
    for building_block in (fn, arg):
        py_typecheck.check_type(
            building_block,
            computation_building_blocks.ComputationBuildingBlock)
    result_type = computation_types.FederatedType(fn.type_signature.result,
                                                  placements.CLIENTS)
    federated_map = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_MAP.uri,
        computation_types.FunctionType(
            [fn.type_signature, arg.type_signature], result_type))
    return computation_building_blocks.Call(
        federated_map, computation_building_blocks.Tuple((fn, arg)))
コード例 #17
0
    def test_replace_intrinsic_replaces_multiple_intrinsics(self):
        """All ten chained GENERIC_PLUS intrinsics get replaced.

        Builds `(arg -> add_one(add_one(...add_one(arg))))` with ten calls, so
        the original maps 1 to 11. After GENERIC_PLUS is replaced by a callable
        that always returns 100, the result is 100.
        """
        param = computation_building_blocks.Reference('arg', tf.int32)
        chain = param
        next_type = param.type_signature
        for _ in range(10):
            add_one = _create_lambda_to_add_one(next_type)
            chain = computation_building_blocks.Call(add_one, chain)
            next_type = add_one.type_signature.result
        comp = computation_building_blocks.Lambda(param.name,
                                                  param.type_signature, chain)
        uri = intrinsic_defs.GENERIC_PLUS.uri
        always_100 = lambda x: 100

        self.assertEqual(_get_number_of_intrinsics(comp, uri), 10)
        self.assertEqual(_to_comp(comp)(1), 11)

        transformed_comp = transformations.replace_intrinsic_with_callable(
            comp, uri, always_100, context_stack_impl.context_stack)

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 0)
        self.assertEqual(_to_comp(transformed_comp)(1), 100)
コード例 #18
0
def create_chained_calls(functions, arg):
  r"""Creates a chain of `n` calls.

       Call
      /    \
  Comp      ...
               \
                Call
               /    \
           Comp      Comp

  The first functional computation in `functions` must have a parameter type
  that is assignable from the type of `arg`; each subsequent functional
  computation must have a parameter type that is assignable from the previous
  one's result type. The first function is applied innermost, so the result
  is `fn_n(...fn_2(fn_1(arg))...)`.

  Args:
    functions: A non-empty Python list (or other iterable) of functional
      computations, each exposing a `parameter_type`.
    arg: A `computation_building_blocks.ComputationBuildingBlock`.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If a function's parameter type is not assignable from the type
      of the value it is applied to.
    ValueError: If `functions` is empty.
  """
  call = None
  for fn in functions:
    if not type_utils.is_assignable_from(fn.parameter_type, arg.type_signature):
      raise TypeError(
          'The parameter of the function is of type {}, and the argument is of '
          'an incompatible type {}.'.format(
              str(fn.parameter_type), str(arg.type_signature)))
    call = computation_building_blocks.Call(fn, arg)
    # The call just built is the argument of the next function in the chain.
    arg = call
  if call is None:
    # With no functions there is no call to return; fail clearly here rather
    # than with an unbound-local error.
    raise ValueError('Expected at least one function to chain, found none.')
  return call
コード例 #19
0
 def transform(self, comp):
     """Pushes a selection into its called graph when `should_transform` allows.

     Returns:
       A `(result, modified)` pair: the call of the narrowed graph on the
       original argument together with `True`, or `comp` unchanged together
       with `False`.
     """
     if self.should_transform(comp):
         narrowed_fn = select_graph_output(
             comp.source.function, index=comp.index, name=comp.name)
         rewritten = computation_building_blocks.Call(narrowed_fn,
                                                      comp.source.argument)
         return rewritten, True
     return comp, False
コード例 #20
0
 def test_returns_string_for_comp_with_right_overhang(self):
     """Checks string renderings of `(a -> <a,data,data,data,data>[0])(data)`.

     In the structural rendering the tuple's element list is wider than
     anything above it, so it extends past the right of the tree (the "right
     overhang").
     """
     ref = computation_building_blocks.Reference('a', tf.int32)
     # One `data` instance is reused for every tuple element after `a` and for
     # the call argument.
     data = computation_building_blocks.Data('data', tf.int32)
     tup = computation_building_blocks.Tuple([ref, data, data, data, data])
     sel = computation_building_blocks.Selection(tup, index=0)
     fn = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                             sel)
     comp = computation_building_blocks.Call(fn, data)
     compact_string = comp.compact_representation()
     self.assertEqual(compact_string,
                      '(a -> <a,data,data,data,data>[0])(data)')
     # The formatted rendering puts each tuple element on its own line.
     formatted_string = comp.formatted_representation()
     # pyformat: disable
     self.assertEqual(
         formatted_string, '(a -> <\n'
         '  a,\n'
         '  data,\n'
         '  data,\n'
         '  data,\n'
         '  data\n'
         '>[0])(data)')
     # pyformat: enable
     structural_string = comp.structural_representation()
     # pyformat: disable
     self.assertEqual(
         structural_string, '          Call\n'
         '         /    \\\n'
         'Lambda(a)      data\n'
         '|\n'
         'Sel(0)\n'
         '|\n'
         'Tuple\n'
         '|\n'
         '[Ref(a), data, data, data, data]')
     # pyformat: enable
コード例 #21
0
  def sequence_reduce(self, value, zero, op):
    """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

    A plain sequence is reduced by calling the `SEQUENCE_REDUCE` intrinsic
    directly. A federated sequence at SERVER or CLIENTS is reduced per member
    by wrapping that call in a lambda, which is passed to `federated_apply` or
    `federated_map` respectively.

    Args:
      value: As in `api/intrinsics.py`.
      zero: As in `api/intrinsics.py`.
      op: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`, including when `op` does not match
        the expected reduction operator type or the placement is unsupported.
    """
    stack = self._context_stack
    value, zero, op = [
        value_impl.to_value(operand, None, stack)
        for operand in (value, zero, op)
    ]
    value_type = value.type_signature
    if isinstance(value_type, computation_types.SequenceType):
      element_type = value_type.element
    else:
      # Besides a plain sequence, only a federated sequence is accepted.
      py_typecheck.check_type(value_type, computation_types.FederatedType)
      py_typecheck.check_type(value_type.member,
                              computation_types.SequenceType)
      element_type = value_type.member.element

    expected_op_type = type_constructors.reduction_op(zero.type_signature,
                                                      element_type)
    if not type_utils.is_assignable_from(expected_op_type, op.type_signature):
      raise TypeError('Expected an operator of type {}, got {}.'.format(
          str(expected_op_type), str(op.type_signature)))

    sequence_type = computation_types.SequenceType(element_type)
    reduce_intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.SEQUENCE_REDUCE.uri,
        computation_types.FunctionType(
            [sequence_type, zero.type_signature, op.type_signature],
            zero.type_signature))

    if isinstance(value_type, computation_types.SequenceType):
      return value_impl.ValueImpl(reduce_intrinsic, stack)(value, zero, op)

    # Federated case: `(arg -> SEQUENCE_REDUCE(<arg, zero, op>))` per member.
    reduce_body = computation_building_blocks.Call(
        reduce_intrinsic,
        computation_building_blocks.Tuple([
            computation_building_blocks.Reference('arg', sequence_type),
            value_impl.ValueImpl.get_comp(zero),
            value_impl.ValueImpl.get_comp(op),
        ]))
    mapping_fn = value_impl.ValueImpl(
        computation_building_blocks.Lambda('arg', sequence_type, reduce_body),
        stack)
    placement = value_type.placement
    if placement is placements.SERVER:
      return self.federated_apply(mapping_fn, value)
    if placement is placements.CLIENTS:
      return self.federated_map(mapping_fn, value)
    raise TypeError('Unsupported placement {}.'.format(str(placement)))
コード例 #22
0
def _create_lambda_to_dummy_intrinsic(type_spec, uri='dummy'):
    r"""Builds `(arg -> <uri>(arg))` around a `type_spec -> type_spec` intrinsic.

    Lambda
          \
           Call
          /    \
    Intrinsic   Ref(arg)

    Args:
      type_spec: The `tf.dtypes.DType` of the parameter and of the intrinsic's
        input and output.
      uri: The URI given to the intrinsic.

    Returns:
      A `computation_building_blocks.Lambda`.

    Raises:
      TypeError: If `type_spec` is not a `tf.dtypes.DType`.
    """
    py_typecheck.check_type(type_spec, tf.dtypes.DType)
    dummy = computation_building_blocks.Intrinsic(
        uri, computation_types.FunctionType(type_spec, type_spec))
    param = computation_building_blocks.Reference('arg', type_spec)
    return computation_building_blocks.Lambda(
        param.name, param.type_signature,
        computation_building_blocks.Call(dummy, param))
コード例 #23
0
def create_sequence_map(fn, arg):
    r"""Creates a called sequence map.

            Call
           /    \
    Intrinsic    Tuple
                 |
                 [Comp, Comp]

    The intrinsic is typed `(<fn type, arg type> -> fn result*)`, i.e. it
    returns a sequence of `fn`'s result type.

    Args:
      fn: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the function.
      arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the argument.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If `fn` or `arg` is not a
        `computation_building_blocks.ComputationBuildingBlock`.
    """
    # NOTE(review): `arg` is not checked to be a sequence, nor is its element
    # type checked against `fn`'s parameter; the intrinsic type is derived
    # from the given signatures as-is.
    for building_block in (fn, arg):
        py_typecheck.check_type(
            building_block,
            computation_building_blocks.ComputationBuildingBlock)
    mapped_type = computation_types.SequenceType(fn.type_signature.result)
    sequence_map = computation_building_blocks.Intrinsic(
        intrinsic_defs.SEQUENCE_MAP.uri,
        computation_types.FunctionType(
            (fn.type_signature, arg.type_signature), mapped_type))
    return computation_building_blocks.Call(
        sequence_map, computation_building_blocks.Tuple((fn, arg)))
コード例 #24
0
def construct_map_or_apply(fn, arg):
  """Injects intrinsic to allow application of `fn` to federated `arg`.

  Server-placed values are handled with a `FEDERATED_APPLY` call built here.
  Client-placed values are delegated to `create_federated_map`.

  Args:
    fn: An instance of `computation_building_blocks.ComputationBuildingBlock` of
      functional type to be wrapped with intrinsic in order to call on `arg`.
    arg: `computation_building_blocks.ComputationBuildingBlock` instance of
      federated type for which to construct intrinsic in order to call `fn` on
      `arg`. `member` of `type_signature` of `arg` must be assignable to
      `parameter` of `type_signature` of `fn`.

  Returns:
    A `computation_building_blocks.Call` of the chosen intrinsic on `<fn, arg>`.

  Raises:
    TypeError: If `arg` is placed anywhere other than `SERVER` or `CLIENTS`.
      The `py_typecheck` / `type_utils` checks at the top may also reject
      mismatched inputs.
  """
  py_typecheck.check_type(fn,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  type_utils.check_assignable_from(fn.type_signature.parameter,
                                   arg.type_signature.member)
  placement = arg.type_signature.placement
  if placement == placement_literals.SERVER:
    # The result keeps the placement and all_equal bit of `arg`.
    result_type = computation_types.FederatedType(fn.type_signature.result,
                                                  placement,
                                                  arg.type_signature.all_equal)
    intrinsic_type = computation_types.FunctionType(
        [fn.type_signature, arg.type_signature], result_type)
    intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_APPLY.uri, intrinsic_type)
    tup = computation_building_blocks.Tuple((fn, arg))
    return computation_building_blocks.Call(intrinsic, tup)
  elif placement == placement_literals.CLIENTS:
    return create_federated_map(fn, arg)
  # Any other placement previously fell through and returned None silently.
  raise TypeError('Unsupported placement {}.'.format(placement))
コード例 #25
0
def create_federated_value(value, placement):
    r"""Builds a call that places `value` at `placement` as an all-equal value.

              Call
             /    \
    Intrinsic      Comp

    The intrinsic is `FEDERATED_VALUE_AT_CLIENTS` or
    `FEDERATED_VALUE_AT_SERVER`, typed `(T -> T@placement)` with the result
    marked all-equal.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the value.
      placement: A `placement_literals.PlacementLiteral` to use as the placement.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match, or if `placement` is neither
        `CLIENTS` nor `SERVER`.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    if placement is placement_literals.SERVER:
        uri = intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri
    elif placement is placement_literals.CLIENTS:
        uri = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri
    else:
        raise TypeError('Unsupported placement {}.'.format(placement))
    value_type = value.type_signature
    placed_type = computation_types.FederatedType(value_type, placement, True)
    return computation_building_blocks.Call(
        computation_building_blocks.Intrinsic(
            uri, computation_types.FunctionType(value_type, placed_type)),
        value)
コード例 #26
0
def create_federated_sum(value):
    r"""Creates a called federated sum.

              Call
             /    \
    Intrinsic      Comp

    The intrinsic is `FEDERATED_SUM`, typed `(value_type -> member@SERVER)`,
    with the result marked all-equal.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` of
        federated type whose member values are summed.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If `value` is not a building block, or its type signature is
        not a `computation_types.FederatedType`.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    # Reject non-federated inputs with the documented TypeError rather than an
    # AttributeError on `.member` below.
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    result_type = computation_types.FederatedType(value.type_signature.member,
                                                  placement_literals.SERVER,
                                                  True)
    intrinsic_type = computation_types.FunctionType(value.type_signature,
                                                    result_type)
    intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_SUM.uri, intrinsic_type)
    return computation_building_blocks.Call(intrinsic, value)
コード例 #27
0
def create_federated_map(fn, arg):
    r"""Builds a call to the `FEDERATED_MAP` intrinsic on `<fn, arg>`.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [fn, arg]

    The intrinsic is typed `(<fn_type, arg_type> -> R@CLIENTS)`, where `R` is
    the result type of `fn` and the result is not all-equal.

    Args:
      fn: A `computation_building_blocks.ComputationBuildingBlock` to use as the
        function.
      arg: A `computation_building_blocks.ComputationBuildingBlock` to use as the
        argument.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match, e.g. `fn` or `arg` is not a
        `computation_building_blocks.ComputationBuildingBlock`.
    """
    for comp in (fn, arg):
        py_typecheck.check_type(
            comp, computation_building_blocks.ComputationBuildingBlock)
    fn_type = fn.type_signature
    clients_result = computation_types.FederatedType(
        fn_type.result, placement_literals.CLIENTS, False)
    map_intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_MAP.uri,
        computation_types.FunctionType((fn_type, arg.type_signature),
                                       clients_result))
    return computation_building_blocks.Call(
        map_intrinsic, computation_building_blocks.Tuple((fn, arg)))
コード例 #28
0
def _create_zip_two_values(value):
    r"""Creates a called federated zip with two values.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp1, Comp2]

    Notice that this function will drop any names associated to the two-tuple it
    is processing. This is necessary due to the type signature of the
    underlying federated zip intrinsic, `<T@P,U@P>-><T,U>@P`. Keeping names
    here would violate this type signature. The names are cached at a higher
    level than this function, and appended to the resulting tuple in a single
    call to `federated_map` or `federated_apply` before the resulting structure
    is sent back to the caller.

    The placement is taken from the first element of `value`. `CLIENTS` uses
    `FEDERATED_ZIP_AT_CLIENTS` with all_equal False. `SERVER` uses
    `FEDERATED_ZIP_AT_SERVER` with all_equal True.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` with a
        `type_signature` of type `computation_types.NamedTupleType` containing
        exactly two elements.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match, or the placement is neither
        `CLIENTS` nor `SERVER`.
      ValueError: If `value` does not contain exactly two elements.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length != 2:
        # Report the element count the message promises, not the element list.
        raise ValueError(
            'Expected a value with exactly two elements, received {} elements.'
            .format(length))
    placement = value[0].type_signature.placement
    if placement is placement_literals.CLIENTS:
        uri = intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri
        all_equal = False
    elif placement is placement_literals.SERVER:
        uri = intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri
        all_equal = True
    else:
        raise TypeError('Unsupported placement {}.'.format(placement))
    # Both parameter and result use unnamed (None) entries to match the
    # intrinsic's unnamed signature; see the docstring.
    parameter_type = computation_types.NamedTupleType([
        (None,
         computation_types.FederatedType(type_signature.member, placement,
                                         all_equal))
        for _, type_signature in named_type_signatures
    ])
    result_type = computation_types.FederatedType(
        [(None, e.member) for _, e in named_type_signatures], placement,
        all_equal)
    intrinsic_type = computation_types.FunctionType(parameter_type,
                                                    result_type)
    intrinsic = computation_building_blocks.Intrinsic(uri, intrinsic_type)
    return computation_building_blocks.Call(intrinsic, value)
コード例 #29
0
    def _transform(comp):
        """Collapses two nested intrinsic calls into a single call.

        When `_should_transform(comp)` is false, `comp` is returned unchanged.
        Otherwise `comp` is read as `outer(<f, inner(<g, x>)>)` and rebuilt as
        `outer(<(let fn=<g, f> in (arg -> fn[1](fn[0](arg)))), x>)`. That shape
        is presumably guaranteed by `_should_transform` (confirm). The new
        intrinsic reuses the URI of `comp.function` and its original result
        type. Its parameter type is taken from the new argument tuple.

        Args:
          comp: The computation to inspect.

        Returns:
          A `(computation, modified)` pair. `modified` is `True` only when a
          rewrite was performed.
        """
        if not _should_transform(comp):
            return comp, False

        def _create_block_to_chained_calls(comps):
            r"""Builds a block that applies `comps` in order to one argument.

                         Block
                        /     \
              [fn=Tuple]       Lambda(arg)
                  |                       \
          [Comp(y), Comp(x)]               Call
                                          /    \
                                    Sel(1)      Call
                                   /           /    \
                            Ref(fn)      Sel(0)      Ref(arg)
                                        /
                                 Ref(fn)

            For two functions this is
            `(let fn=<y, x> in (arg -> fn[1](fn[0](arg))))`. The lambda
            parameter takes the parameter type of `comps[0]`.

            Args:
              comps: A non-empty Python sequence of functional computations,
                applied first to last.

            Returns:
              A `computation_building_blocks.Block`.
            """
            functions = computation_building_blocks.Tuple(comps)
            fn_ref = computation_building_blocks.Reference(
                'fn', functions.type_signature)
            arg_ref = computation_building_blocks.Reference(
                'arg', comps[0].type_signature.parameter)
            chained = arg_ref
            for position in range(len(comps)):
                selected = computation_building_blocks.Selection(
                    fn_ref, index=position)
                chained = computation_building_blocks.Call(selected, chained)
            lam = computation_building_blocks.Lambda(
                arg_ref.name, arg_ref.type_signature, chained)
            return computation_building_blocks.Block([('fn', functions)], lam)

        inner_call = comp.argument[1]
        composed = _create_block_to_chained_calls(
            (inner_call.argument[0], comp.argument[0]))
        new_arg = computation_building_blocks.Tuple(
            [composed, inner_call.argument[1]])
        new_intrinsic = computation_building_blocks.Intrinsic(
            comp.function.uri,
            computation_types.FunctionType(
                new_arg.type_signature, comp.function.type_signature.result))
        return computation_building_blocks.Call(new_intrinsic, new_arg), True
コード例 #30
0
def construct_binary_operator_with_upcast(type_signature, operator):
  """Builds a lambda that upcasts its second operand, then applies `operator`.

  The concept of upcasting is explained further in the docstring for
  `apply_binary_operator_with_upcast`.

  The lambda's body must be reducible to TensorFlow, because it may serve as
  the body of an intrinsic. Therefore `type_signature` may only contain named
  tuple or tensor elements; federated types cannot be handled generically
  here.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`,
      with two elements, both of the same type or the second able to be upcast
      to the first, as explained in `apply_binary_operator_with_upcast`, and
      both containing only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `computation_building_blocks.Lambda`. Its body upcasts the second
    element of its argument when needed and applies the binary operator.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  arg_ref = computation_building_blocks.Reference('binary_operator_arg',
                                                  type_signature)

  def _pack_into_type(to_pack, type_spec):
    """Broadcasts tensor `to_pack` into every leaf of the structure `type_spec`.

    Returns None for any type other than a named tuple or a tensor.
    """
    if isinstance(type_spec, computation_types.NamedTupleType):
      return computation_building_blocks.Tuple([
          (name, _pack_into_type(to_pack, member_type))
          for name, member_type in anonymous_tuple.to_elements(type_spec)
      ])
    if isinstance(type_spec, computation_types.TensorType):
      broadcast_fn = (
          computation_constructing_utils
          .construct_tensorflow_to_broadcast_scalar(
              to_pack.type_signature.dtype, type_spec.shape))
      return computation_building_blocks.Call(broadcast_fn, to_pack)
    return None

  rhs = computation_building_blocks.Selection(arg_ref, index=1)
  lhs = computation_building_blocks.Selection(arg_ref, index=0)
  lhs_type = lhs.type_signature

  # Only restructure the second operand when its type differs from the first.
  if not type_utils.are_equivalent_types(lhs_type, rhs.type_signature):
    rhs = _pack_into_type(rhs, lhs_type)

  tf_operator = (
      computation_constructing_utils.construct_tensorflow_binary_operator(
          lhs_type, operator))
  operands = computation_building_blocks.Tuple([lhs, rhs])
  return computation_building_blocks.Lambda(
      arg_ref.name, arg_ref.type_signature,
      computation_building_blocks.Call(tf_operator, operands))