예제 #1
0
 def test_child_name_doesnt_conflict(self):
     """Nested contexts sharing a suggested name receive increasing suffixes."""
     stack = context_stack_impl.context_stack
     root = federated_computation_context.FederatedComputationContext(
         stack, suggested_name='FOO')
     self.assertEqual(root.name, 'FOO')
     child = federated_computation_context.FederatedComputationContext(
         stack, suggested_name='FOO', parent=root)
     self.assertEqual(child.name, 'FOO_1')
     grandchild = federated_computation_context.FederatedComputationContext(
         stack, suggested_name='FOO', parent=child)
     self.assertEqual(grandchild.name, 'FOO_2')
def zero_or_one_arg_fn_to_building_block(
    fn,
    parameter_name: Optional[str],
    parameter_type: Optional[Any],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
) -> building_blocks.ComputationBuildingBlock:
    """Converts a zero- or one-argument `fn` into a computation building block.

    `fn` is traced inside a fresh `FederatedComputationContext` that is
    installed on `context_stack` for the duration of the call. If the current
    context is already a `FederatedComputationContext`, it becomes the parent
    of the new one.

    Args:
      fn: A function with 0 or 1 arguments that contains orchestration logic,
        i.e., that expects zero or one `values_base.Value` and returns a result
        convertible to the same.
      parameter_name: The name of the parameter, or `None` if there isn't any.
        If given, it must be a string; it is prefixed with the new context's
        name before being used as the reference name.
      parameter_type: The TFF type of the parameter, or `None` if there's none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts. If not `None`, it must be a string.

    Returns:
      A `building_blocks.Lambda` whose parameter is the (prefixed)
      `parameter_name` of type `parameter_type` (both `None` when `fn` takes
      no argument) and whose result is the traced body of `fn`.

    Raises:
      ValueError: if `fn` returns `None`, or if `fn` is incompatible with
        `parameter_type`.
    """
    py_typecheck.check_callable(fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, str)
    parameter_type = computation_types.to_type(parameter_type)
    # Nest under an enclosing federated context (if any) so that the new
    # context's name can be made distinct from its ancestors'.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, str)
        # Qualify the parameter with the context name; presumably this keeps
        # references from nested computations distinct.
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        if parameter_type is not None:
            result = fn(
                value_impl.ValueImpl(
                    building_blocks.Reference(parameter_name, parameter_type),
                    context_stack))
        else:
            result = fn()
        if result is None:
            # Not every callable has `__code__` (e.g. `functools.partial`
            # objects or instances defining `__call__`); fall back to the
            # callable's repr so the intended ValueError is not masked by an
            # AttributeError.
            code = getattr(fn, '__code__', None)
            if code is not None:
                origin = 'The function defined on line {} of file {}'.format(
                    code.co_firstlineno, code.co_filename)
            else:
                origin = 'The function {!r}'.format(fn)
            raise ValueError(
                '{} has returned a `NoneType`, but all TFF functions must '
                'return some non-`None` value.'.format(origin))
        result = value_impl.to_value(result, None, context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        return building_blocks.Lambda(parameter_name, parameter_type,
                                      result_comp)
예제 #3
0
 def test_invoke_returns_value_with_correct_type(self):
     """Invoking a no-argument TF computation yields an `int32` TFF value."""
     fc_context = federated_computation_context.FederatedComputationContext(
         context_stack_impl.context_stack)
     constant_comp = computations.tf_computation(lambda: tf.constant(10))
     invoked = fc_context.invoke(constant_comp, None)
     self.assertIsInstance(invoked, value_base.Value)
     type_str = str(invoked.type_signature)
     self.assertEqual(type_str, 'int32')
예제 #4
0
def zero_or_one_arg_func_to_building_block(func,
                                           parameter_name,
                                           parameter_type,
                                           context_stack,
                                           suggested_name=None):
    """Converts a zero- or one-argument `func` into a computation building block.

    `func` is traced inside a fresh `FederatedComputationContext` installed on
    `context_stack`; if the current context is already a
    `FederatedComputationContext`, it becomes the parent of the new one.

    Args:
      func: A function with 0 or 1 arguments that contains orchestration logic,
        i.e., that expects zero or one `values_base.Value` and returns a result
        convertible to the same.
      parameter_name: The name of the parameter, or `None` if there isn't any.
        If given, it must be a string; it is prefixed with the new context's
        name before being used as the reference name.
      parameter_type: The TFF type of the parameter, or `None` if there's none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts. If not `None`, it must be a string.

    Returns:
      An instance of `computation_building_blocks.ComputationBuildingBlock` that
      contains the logic from `func`: the traced result itself when
      `parameter_type` is `None`, otherwise a
      `computation_building_blocks.Lambda` over the (prefixed) parameter.

    Raises:
      ValueError: if `func` returns `None`, or if `func` is incompatible with
        `parameter_type`.
    """
    py_typecheck.check_callable(func)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, six.string_types)
    parameter_type = computation_types.to_type(parameter_type)
    # Nest under an enclosing federated context (if any) so that the new
    # context's name can be made distinct from its ancestors'.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, six.string_types)
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        if parameter_type is not None:
            result = func(
                value_impl.ValueImpl(
                    computation_building_blocks.Reference(
                        parameter_name, parameter_type), context_stack))
        else:
            result = func()
        if result is None:
            # Reject a missing return value explicitly instead of letting it
            # fall through to `to_value`. Not every callable has `__code__`
            # (e.g. `functools.partial` objects), so fall back to its repr.
            code = getattr(func, '__code__', None)
            if code is not None:
                origin = 'The function defined on line {} of file {}'.format(
                    code.co_firstlineno, code.co_filename)
            else:
                origin = 'The function {!r}'.format(func)
            raise ValueError(
                '{} has returned a `NoneType`, but all TFF functions must '
                'return some non-`None` value.'.format(origin))
        result = value_impl.to_value(result, None, context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        if parameter_type is None:
            # A zero-argument function is represented by its body alone.
            return result_comp
        return computation_building_blocks.Lambda(parameter_name,
                                                  parameter_type,
                                                  result_comp)
예제 #5
0
 def test_federated_value_raw_np_scalar(self):
     """`federated_value` accepts raw NumPy scalars and keeps their dtype."""
     fc_context = federated_computation_context.FederatedComputationContext(
         context_stack_impl.context_stack)
     with context_stack_impl.context_stack.install(fc_context):
         cases = ((np.float64(0), 'float64@SERVER'),
                  (np.int64(0), 'int64@SERVER'))
         for scalar, expected_type in cases:
             federated = intrinsics.federated_value(scalar, placements.SERVER)
             self.assertEqual(str(federated.type_signature), expected_type)
예제 #6
0
    def test_bind_single_computation_to_reference(self):
        """Binding one computation records exactly one matching symbol."""
        fc_context = federated_computation_context.FederatedComputationContext(
            context_stack_impl.context_stack)
        int_data = building_blocks.Data('x', tf.int32)
        reference = fc_context.bind_computation_to_reference(int_data)
        bindings = fc_context.symbol_bindings
        first_bound_name = bindings[0][0]

        self.assertIsInstance(reference, building_blocks.Reference)
        self.assertEqual(reference.type_signature, int_data.type_signature)
        self.assertLen(bindings, 1)
        self.assertEqual(first_bound_name, reference.name)
 def test_something(self):
     """Checks `invoke` typing and suffixing of nested context names."""
     stack = context_stack_impl.context_stack
     root = federated_computation_context.FederatedComputationContext(stack)
     constant_comp = computations.tf_computation(lambda: tf.constant(10))
     invoked = root.invoke(constant_comp, None)
     self.assertIsInstance(invoked, value_base.Value)
     self.assertEqual(str(invoked.type_signature), 'int32')
     self.assertEqual(root.name, 'FEDERATED')
     # Each child suggests 'FOO' under the previous one; names must not clash.
     parent = root
     for expected_name in ('FOO', 'FOO_1', 'FOO_2'):
         child = federated_computation_context.FederatedComputationContext(
             stack, suggested_name='FOO', parent=parent)
         self.assertEqual(child.name, expected_name)
         parent = child
예제 #8
0
    def test_bind_two_computations_to_reference(self):
        """Two bindings on one context yield two symbols in binding order."""
        fc_context = federated_computation_context.FederatedComputationContext(
            context_stack_impl.context_stack)
        int_data = building_blocks.Data('x', tf.int32)
        float_data = building_blocks.Data('x', tf.float32)
        int_ref = fc_context.bind_computation_to_reference(int_data)
        float_ref = fc_context.bind_computation_to_reference(float_data)
        bindings = fc_context.symbol_bindings

        pairs = ((int_ref, int_data), (float_ref, float_data))
        for ref, _ in pairs:
            self.assertIsInstance(ref, building_blocks.Reference)

        for ref, data in pairs:
            self.assertEqual(ref.type_signature, data.type_signature)
        self.assertLen(bindings, 2)
        for index, (ref, _) in enumerate(pairs):
            self.assertEqual(bindings[index][0], ref.name)
예제 #9
0
        computation_types.NamedTupleType(
            (computation_types.FederatedType(tf.int32,
                                             placement_literals.CLIENTS),
             computation_types.FederatedType(tf.int32,
                                             placement_literals.CLIENTS))))
    def _(x):
      x = value_impl.to_value(x, None, _context_stack)
      value_utils.ensure_federated_value(x)
      return x

  def test_ensure_federated_value_fails_on_unzippable(self):
    """A tuple mixing CLIENTS and SERVER placements cannot be federated."""
    mixed_placement_type = computation_types.NamedTupleType(
        (computation_types.FederatedType(tf.int32, placement_literals.CLIENTS),
         computation_types.FederatedType(tf.int32, placement_literals.SERVER)))

    @computations.federated_computation(mixed_placement_type)
    def _(x):
      x = value_impl.to_value(x, None, _context_stack)
      with self.assertRaises(TypeError):
        value_utils.ensure_federated_value(x)
      return x


if __name__ == '__main__':
  # Run the whole test suite with a default federated computation context.
  _default_context = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)
  with context_stack_impl.context_stack.install(_default_context):
    absltest.main()
예제 #10
0
 def run(self, result=None):
     """Runs each test with a fresh `FederatedComputationContext` installed."""
     stack = context_stack_impl.context_stack
     with stack.install(
             federated_computation_context.FederatedComputationContext(stack)):
         super(ValueImplTest, self).run(result)
예제 #11
0
 def test_parent_populated_correctly(self):
     """A child context exposes the context it was created under as `parent`."""
     stack = context_stack_impl.context_stack
     outer = federated_computation_context.FederatedComputationContext(stack)
     inner = federated_computation_context.FederatedComputationContext(
         stack, parent=outer)
     self.assertIs(inner.parent, outer)
예제 #12
0
 def test_suggested_name_populates_name_attribute(self):
     """A context without a parent adopts its suggested name as-is."""
     fc_context = federated_computation_context.FederatedComputationContext(
         context_stack_impl.context_stack, suggested_name='FOO')
     actual_name = fc_context.name
     self.assertEqual(actual_name, 'FOO')
예제 #13
0
 def test_construction_populates_name(self):
     """Without a suggested name, the context is named 'FEDERATED'."""
     stack = context_stack_impl.context_stack
     fc_context = federated_computation_context.FederatedComputationContext(stack)
     self.assertEqual(fc_context.name, 'FEDERATED')