コード例 #1
0
 def test_throws_on_unresolvable_function_call(self):
     """Both aggregation finders raise on a call to an opaque function.

     The callee is a bare `Data` block typed as a no-argument function whose
     result is placed at clients; each finder is expected to raise
     `ValueError` when handed the resulting call.
     """
     clients_int = computation_types.at_clients(tf.int32)
     opaque_fn_type = computation_types.FunctionType(None, clients_int)
     opaque_fn = building_blocks.Data('unknown_func', opaque_fn_type)
     comp = building_blocks.Call(opaque_fn)
     self.assertRaises(
         ValueError, tree_analysis.find_unsecure_aggregation_in_tree, comp)
     self.assertRaises(
         ValueError, tree_analysis.find_secure_aggregation_in_tree, comp)
コード例 #2
0
    def test_throws_on_unresolvable_function_call(self):
        """Both aggregation finders raise on a call to a function parameter.

        The computation under test receives `unknown_func` as its parameter
        and calls it; the declared result type of that function is an int32
        placed at clients. Each finder is expected to raise `ValueError`.
        """
        no_arg_ty = ()
        clients_int_ty = computation_types.FederatedType(
            tf.int32, placement_literals.CLIENTS)
        unknown_func_ty = computation_types.FunctionType(
            no_arg_ty, clients_int_ty)

        @computations.federated_computation(unknown_func_ty)
        def comp(unknown_func):
            return unknown_func(())

        # The building block is produced inside each context so that the
        # scope covered by `assertRaises` matches the assertion being made.
        with self.assertRaises(ValueError):
            unsecure_tree = comp.to_building_block()
            tree_analysis.find_unsecure_aggregation_in_tree(unsecure_tree)
        with self.assertRaises(ValueError):
            secure_tree = comp.to_building_block()
            tree_analysis.find_secure_aggregation_in_tree(secure_tree)
コード例 #3
0
 def test_returns_none(self, comp):
     """Neither finder reports anything for the supplied computation.

     Args:
       comp: A computation exposing `to_building_block()`.
     """
     unsecure_hits = tree_analysis.find_unsecure_aggregation_in_tree(
         comp.to_building_block())
     self.assertEmpty(unsecure_hits)

     secure_hits = tree_analysis.find_secure_aggregation_in_tree(
         comp.to_building_block())
     self.assertEmpty(secure_hits)
コード例 #4
0
  def test_returns_none_on_unresolvable_function_call_with_non_federated_output(
      self):
    """Finders return nothing when an opaque call yields an unplaced value.

    The opaque function takes a clients-placed int32 and returns a plain
    `tf.int32`; both finders are expected to produce an empty result.
    """
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    unplaced_int = tf.int32
    opaque_fn = building_blocks.Data(
        'unknown_func',
        computation_types.FunctionType(clients_int, unplaced_int))
    argument = building_blocks.Data('client_data', clients_int)
    comp = building_blocks.Call(opaque_fn, argument)

    unsecure_hits = tree_analysis.find_unsecure_aggregation_in_tree(comp)
    self.assertEmpty(unsecure_hits)
    secure_hits = tree_analysis.find_secure_aggregation_in_tree(comp)
    self.assertEmpty(secure_hits)
コード例 #5
0
    def test_returns_none_on_unresolvable_function_call_with_non_federated_output(
            self):
        """Finders return nothing when an opaque call yields an unplaced value.

        The computation under test receives `unknown_func` as its parameter
        and calls it with a clients-placed value; the declared result type of
        that function is a plain `tf.int32`. Both finders are expected to
        produce an empty result.
        """
        clients_int_ty = computation_types.FederatedType(
            tf.int32, placement_literals.CLIENTS)
        unplaced_int_ty = tf.int32
        unknown_func_ty = computation_types.FunctionType(
            clients_int_ty, unplaced_int_ty)

        @computations.federated_computation(unknown_func_ty)
        def comp(unknown_func):
            client_value = intrinsics.federated_value(
                1, placement_literals.CLIENTS)
            return unknown_func(client_value)

        unsecure_hits = tree_analysis.find_unsecure_aggregation_in_tree(
            comp.to_building_block())
        self.assertEmpty(unsecure_hits)

        secure_hits = tree_analysis.find_secure_aggregation_in_tree(
            comp.to_building_block())
        self.assertEmpty(secure_hits)
コード例 #6
0
ファイル: static_assert.py プロジェクト: oodunsi1/federated-1
def assert_not_contains_unsecure_aggregation(comp):
    """Checks that `comp` makes no unsecure aggregation calls.

    Args:
      comp: A `tff.Computation`, often a function annotated with
        `tff.federated_computation` or `tff.tf_computation`. Note that
        polymorphic functions (those without the types of their arguments
        explicitly specified) will not yet be `tff.Computation`s.

    Raises:
      AssertionError: If `comp` contains an unsecure aggregation call.
      ValueError: If `comp` contains a call whose target function cannot be
        identified. This may result from calls to references or other
        indirect structures.
    """
    py_typecheck.check_type(comp, computation_impl.ComputationImpl)
    unsecure_calls = tree_analysis.find_unsecure_aggregation_in_tree(
        comp.to_building_block())
    # Nothing found: the computation passes the check.
    if len(unsecure_calls) == 0:  # pylint: disable=g-explicit-length-test
        return
    _raise_expected_none(unsecure_calls, 'unsecure')
コード例 #7
0
 def test_returns_one_on_unsecure_aggregation(self, comp):
     """The unsecure finder reports exactly one result for `comp`.

     Args:
       comp: The computation tree handed to the finder.
     """
     unsecure_hits = tree_analysis.find_unsecure_aggregation_in_tree(comp)
     self.assertLen(unsecure_hits, 1)
コード例 #8
0
 def test_returns_none_on_secure_aggregation(self):
     """The unsecure finder reports nothing for `simple_secure_sum`."""
     unsecure_hits = tree_analysis.find_unsecure_aggregation_in_tree(
         simple_secure_sum)
     self.assertEmpty(unsecure_hits)