示例#1
0
 def test_child_name_doesnt_conflict(self):
     # A nested context that asks for a name already used by an ancestor gets
     # a numeric suffix, and each further level of nesting bumps the suffix.
     stack = context_stack_impl.context_stack
     make_context = federated_computation_context.FederatedComputationContext
     root = make_context(stack, suggested_name='FOO')
     self.assertEqual(root.name, 'FOO')
     child = make_context(stack, suggested_name='FOO', parent=root)
     self.assertEqual(child.name, 'FOO_1')
     grandchild = make_context(stack, suggested_name='FOO', parent=child)
     self.assertEqual(grandchild.name, 'FOO_2')
    def test_ingest_zips_federated_under_struct(self):
        # `fn` takes a one-element struct wrapping an `x`/`y` mapping of
        # CLIENTS-placed int32 values. The argument built below mirrors that
        # nesting out of references; the test passes if the call does not
        # raise.
        @federated_computation.federated_computation(
            computation_types.StructType([
                (None,
                 collections.OrderedDict(
                     x=computation_types.at_clients(tf.int32),
                     y=computation_types.at_clients(tf.int32))),
            ]))
        def fn(_):
            return ()

        # <<{int32}@CLIENTS, {int32}@CLIENTS>>
        x_ref = building_blocks.Reference(
            'x', computation_types.FederatedType(tf.int32, placements.CLIENTS))
        y_ref = building_blocks.Reference(
            'y', computation_types.FederatedType(tf.int32, placements.CLIENTS))
        inner_struct = building_blocks.Struct([x_ref, y_ref])
        arg = building_blocks.Struct([inner_struct])

        fc_context = federated_computation_context.FederatedComputationContext(
            context_stack_impl.context_stack)
        with context_stack_impl.context_stack.install(fc_context):
            fn(arg)
示例#3
0
 def test_invoke_returns_value_with_correct_type(self):
     # Invoking a no-argument TF computation while the federated context is
     # installed should give back a `value_impl.Value` typed as int32.
     stack = context_stack_impl.context_stack
     fc_context = federated_computation_context.FederatedComputationContext(
         stack)
     constant_comp = computations.tf_computation(lambda: tf.constant(10))
     with stack.install(fc_context):
         invoked = fc_context.invoke(constant_comp, None)
     self.assertIsInstance(invoked, value_impl.Value)
     self.assertEqual(str(invoked.type_signature), 'int32')
示例#4
0
def federated_computation_serializer(
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
):
    """Converts a function into a computation building block.

    This is a generator that the caller drives in two steps. First, `next()`
    yields the argument to pass to the function being converted. The caller
    then runs that function itself and passes its result back with `send()`,
    which yields the finished building block.

    Args:
      parameter_name: The name of the parameter, or `None` if there isn't any.
      parameter_type: The `tff.Type` of the parameter, or `None` if there's none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts.

    Yields:
      First, the argument to be passed to the function to be converted. This
      is `None` if `parameter_type` is `None`; otherwise it is a
      `value_impl.ValueImpl` wrapping a reference to the parameter, whose name
      is prefixed with the federated context's name.
      Finally, a tuple of `(building_blocks.ComputationBuildingBlock,
      computation_types.Type)`: the function represented via building blocks
      (a `building_blocks.Lambda`) and its `computation_types.FunctionType`.
      The function type's result type is inferred from the value sent back in.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, str)
    # Nest under the current context when it is itself a federated context,
    # so that the new context's name can be made distinct from its ancestors'.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, str)
        # Prefix with the context name (presumably so that parameters of
        # nested computations stay distinct).
        parameter_name = '{}_{}'.format(context.name, parameter_name)
    # The context stays installed across BOTH yields. The caller's function
    # body therefore runs (between `next()` and `send()`) with it as the
    # current context. It is only uninstalled once the generator is closed or
    # exhausted.
    with context_stack.install(context):
        if parameter_type is None:
            result = yield None
        else:
            result = yield (value_impl.ValueImpl(
                building_blocks.Reference(parameter_name, parameter_type),
                context_stack))
        annotated_result_type = type_conversions.infer_type(result)
        result = value_impl.to_value(result, annotated_result_type,
                                     context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        # Computations bound to symbols while the function ran are wrapped
        # around the result as a `Block` of local definitions.
        symbols_bound_in_context = context_stack.current.symbol_bindings
        if symbols_bound_in_context:
            result_comp = building_blocks.Block(
                local_symbols=symbols_bound_in_context, result=result_comp)
        annotated_type = computation_types.FunctionType(
            parameter_type, annotated_result_type)
        yield building_blocks.Lambda(parameter_name, parameter_type,
                                     result_comp), annotated_type
示例#5
0
    def test_bind_single_computation_to_reference(self):
        # Binding one computation should return a reference of the same type
        # and record exactly one binding whose first element is that
        # reference's name.
        context = federated_computation_context.FederatedComputationContext(
            context_stack_impl.context_stack)
        data = building_blocks.Data('x', tf.int32)
        ref = context.bind_computation_to_reference(data)
        symbol_bindings = context.symbol_bindings

        self.assertIsInstance(ref, building_blocks.Reference)
        self.assertEqual(ref.type_signature, data.type_signature)
        # Check the length before indexing, so that an empty binding list
        # fails this assertion instead of raising IndexError.
        self.assertLen(symbol_bindings, 1)
        bound_symbol_name = symbol_bindings[0][0]
        self.assertEqual(bound_symbol_name, ref.name)
    def test_invoke_returns_value_with_correct_type(self):
        # Wrap a constant-10 TF computation proto and invoke it with no
        # argument inside an installed federated context; the result should
        # be a `value_impl.Value` typed as int32.
        stack = context_stack_impl.context_stack
        int32_type = computation_types.TensorType(tf.int32)
        proto, _ = tensorflow_computation_factory.create_constant(
            10, int32_type)
        constant_comp = computation_impl.ConcreteComputation(proto, stack)
        fc_context = federated_computation_context.FederatedComputationContext(
            stack)

        with stack.install(fc_context):
            invoked = fc_context.invoke(constant_comp, None)

        self.assertIsInstance(invoked, value_impl.Value)
        self.assertEqual(str(invoked.type_signature), 'int32')
示例#7
0
    def test_bind_two_computations_to_reference(self):
        # Binding two computations (int32 and float32, both named 'x') should
        # yield two references with matching types, recorded in bind order.
        fc_context = federated_computation_context.FederatedComputationContext(
            context_stack_impl.context_stack)
        int_data = building_blocks.Data('x', tf.int32)
        float_data = building_blocks.Data('x', tf.float32)
        int_ref = fc_context.bind_computation_to_reference(int_data)
        float_ref = fc_context.bind_computation_to_reference(float_data)
        bindings = fc_context.symbol_bindings

        refs = (int_ref, float_ref)
        sources = (int_data, float_data)
        for ref in refs:
            self.assertIsInstance(ref, building_blocks.Reference)

        for ref, source in zip(refs, sources):
            self.assertEqual(ref.type_signature, source.type_signature)
        self.assertLen(bindings, 2)
        for binding, ref in zip(bindings, refs):
            self.assertEqual(binding[0], ref.name)
    def test_ingest_zips_value_when_necessary_to_match_federated_type(self):
        # `fn` expects a single CLIENTS-placed pair: {<int32,int32>}@CLIENTS.
        @federated_computation.federated_computation(
            computation_types.at_clients((tf.int32, tf.int32)))
        def fn(_):
            return ()

        # ...but it is handed a pair of CLIENTS-placed values,
        # <{int32}@CLIENTS,{int32}@CLIENTS>, which must be zipped on ingest.
        refs = [
            building_blocks.Reference(
                name,
                computation_types.FederatedType(tf.int32, placements.CLIENTS))
            for name in ('x', 'y')
        ]
        arg = building_blocks.Struct(refs)

        fc_context = federated_computation_context.FederatedComputationContext(
            context_stack_impl.context_stack)
        with context_stack_impl.context_stack.install(fc_context):
            fn(arg)
 def run(self, result=None):
     # Run every test method of this case with a FederatedComputationContext
     # installed on `context_stack_impl.context_stack`.
     fc_context = federated_computation_context.FederatedComputationContext(
         context_stack_impl.context_stack)
     with context_stack_impl.context_stack.install(fc_context):
         # Return the TestResult from the base implementation, matching the
         # `unittest.TestCase.run` contract. When `result` is None, the base
         # class creates a result object and returns it.
         return super(FederatedSecureSumTest, self).run(result)
      return factory.sequence_reduce(value, 0, add)

    self.assertEqual(foo.type_signature.compact_representation(),
                     '(int32* -> int32)')

  def test_type_signature_with_federated_type(self):
    # Reducing a CLIENTS-placed int32 sequence, seeded with a CLIENTS-placed
    # zero, should produce a CLIENTS-placed int32 result.
    factory = intrinsic_factory.IntrinsicFactory(
        context_stack_impl.context_stack)

    @computations.tf_computation(np.int32, np.int32)
    def add(x, y):
      return x + y

    clients_int_sequence = computation_types.FederatedType(
        computation_types.SequenceType(np.int32), placement_literals.CLIENTS)

    @computations.federated_computation(clients_int_sequence)
    def foo(value):
      return factory.sequence_reduce(
          value, intrinsics.federated_value(0, placement_literals.CLIENTS), add)

    expected_signature = '({int32*}@CLIENTS -> {int32}@CLIENTS)'
    self.assertEqual(foo.type_signature.compact_representation(),
                     expected_signature)


if __name__ == '__main__':
  # Install one FederatedComputationContext on the shared
  # `context_stack_impl.context_stack` around the whole absltest run, so tests
  # execute with it installed unless they install a context of their own.
  context = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)
  with context_stack_impl.context_stack.install(context):
    absltest.main()
示例#11
0
 def test_parent_populated_correctly(self):
     # A context constructed with `parent=` should expose that exact object.
     stack = context_stack_impl.context_stack
     parent_ctx = federated_computation_context.FederatedComputationContext(
         stack)
     child_ctx = federated_computation_context.FederatedComputationContext(
         stack, parent=parent_ctx)
     self.assertIs(child_ctx.parent, parent_ctx)
示例#12
0
 def test_suggested_name_populates_name_attribute(self):
     # With no parent to clash with, the suggested name is used verbatim.
     suggested = 'FOO'
     fc_context = federated_computation_context.FederatedComputationContext(
         context_stack_impl.context_stack, suggested_name=suggested)
     self.assertEqual(fc_context.name, suggested)
示例#13
0
 def test_construction_populates_name(self):
     # Without a suggested name, the context's name defaults to 'FEDERATED'.
     fc_context = federated_computation_context.FederatedComputationContext(
         context_stack_impl.context_stack)
     expected_default_name = 'FEDERATED'
     self.assertEqual(fc_context.name, expected_default_name)
示例#14
0
def zero_or_one_arg_fn_to_building_block(
    fn,
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
) -> Tuple[building_blocks.ComputationBuildingBlock, computation_types.Type]:
    """Converts a zero- or one-argument `fn` into a computation building block.

    Args:
      fn: A function with 0 or 1 arguments that contains orchestration logic,
        i.e., that expects zero or one `values_base.Value` and returns a result
        convertible to the same. Any callable is accepted, including ones
        without a `__code__` attribute (e.g. `functools.partial` objects).
      parameter_name: The name of the parameter, or `None` if there isn't any.
      parameter_type: The `tff.Type` of the parameter, or `None` if there's none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts.

    Returns:
      A tuple of `(building_blocks.ComputationBuildingBlock,
      computation_types.Type)`. The first element is a `building_blocks.Lambda`
      containing the logic from `fn`. The second is the function's
      `computation_types.FunctionType`, whose result type is the (potentially
      annotated) type inferred for the value `fn` returned.

    Raises:
      ValueError: if `fn` returns `None`, or if `fn` is incompatible with
        `parameter_type`.
    """
    py_typecheck.check_callable(fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, str)
    # Nest under the current context when it is itself a federated context,
    # so that the new context's name can be made distinct from its ancestors'.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, str)
        # Prefix with the context name (presumably so that parameters of
        # nested computations stay distinct).
        parameter_name = '{}_{}'.format(context.name, parameter_name)
    with context_stack.install(context):
        if parameter_type is not None:
            result = fn(
                value_impl.ValueImpl(
                    building_blocks.Reference(parameter_name, parameter_type),
                    context_stack))
        else:
            result = fn()
        if result is None:
            code = getattr(fn, '__code__', None)
            if code is not None:
                raise ValueError(
                    'The function defined on line {} of file {} has returned a '
                    '`NoneType`, but all TFF functions must return some non-`None` '
                    'value.'.format(code.co_firstlineno, code.co_filename))
            # `check_callable` admits callables with no `__code__` (partials,
            # objects defining `__call__`, builtins), so there is no source
            # location to report. Fall back to the callable's repr rather
            # than raising AttributeError.
            raise ValueError(
                'The function {!r} has returned a `NoneType`, but all TFF '
                'functions must return some non-`None` value.'.format(fn))
        annotated_result_type = type_conversions.infer_type(result)
        result = value_impl.to_value(result, annotated_result_type,
                                     context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        # Computations bound to symbols while `fn` ran are wrapped around the
        # result as a `Block` of local definitions.
        symbols_bound_in_context = context_stack.current.symbol_bindings
        if symbols_bound_in_context:
            result_comp = building_blocks.Block(
                local_symbols=symbols_bound_in_context, result=result_comp)
        annotated_type = computation_types.FunctionType(
            parameter_type, annotated_result_type)
        return building_blocks.Lambda(parameter_name, parameter_type,
                                      result_comp), annotated_type