def test_next_state_not_assignable_tuple_result(self):
     """A tuple-returning `next_fn` with a float32 state element is rejected.

     `next_fn` returns `(state, x)` with float32 state; the test expects this
     to be incompatible with the state produced by `test_initialize_fn`
     (presumably int32 -- confirm), raising TemplateStateNotAssignableError.
     """

     def _state_and_input(state, x):
         return (tf.cast(state, tf.float32), x)

     tuple_next_fn = computations.tf_computation(tf.float32,
                                                 tf.float32)(_state_and_input)
     with self.assertRaises(errors.TemplateStateNotAssignableError):
         estimation_process.EstimationProcess(
             test_initialize_fn, tuple_next_fn, test_report_fn)
 def test_next_not_tff_computation_raises(self):
     """A plain Python callable is not accepted as `next_fn`."""

     def plain_python_next(state):
         # Deliberately NOT wrapped as a TFF computation.
         return state

     expected_message = r'Expected .*\.Computation, .*'
     with self.assertRaisesRegex(TypeError, expected_message):
         estimation_process.EstimationProcess(
             initialize_fn=test_initialize_fn,
             next_fn=plain_python_next,
             report_fn=test_report_fn)
 def test_map_estimate_not_assignable(self):
     """`map` rejects a function whose parameter cannot accept the estimate.

     The mapping function takes int32; the test expects the estimate reported
     by `test_report_fn` not to be assignable to it (its type is defined
     elsewhere -- confirm).
     """

     @computations.tf_computation(tf.int32)
     def int_identity_fn(estimate):
         return estimate

     base_process = estimation_process.EstimationProcess(
         test_initialize_fn, test_next_fn, test_report_fn)
     with self.assertRaises(estimation_process.EstimateNotAssignableError):
         base_process.map(int_identity_fn)
 def test_construction_with_empty_state_does_not_raise(self):
    """An empty-tuple state is accepted when constructing the process.

    `next_fn` returns `(state, 1.0)`, i.e. a structure whose first element is
    the (empty) state.
    """
    initialize_fn = computations.tf_computation()(lambda: ())
    next_fn = computations.tf_computation(())(lambda x: (x, 1.0))
    report_fn = computations.tf_computation(())(lambda x: x)
    try:
      estimation_process.EstimationProcess(initialize_fn, next_fn, report_fn)
    except Exception as e:  # pylint: disable=broad-except
      # Catch only Exception subclasses so KeyboardInterrupt/SystemExit still
      # propagate, and surface the underlying error in the failure message.
      self.fail('Could not construct an EstimationProcess with empty state: '
                '{}'.format(e))
 def test_federated_report_state_not_assignable(self):
    """A `report_fn` expecting CLIENTS-placed state is rejected.

    `initialize_fn` places an int state at SERVER, while `report_fn` takes an
    int32 placed at CLIENTS, so construction must raise
    TemplateStateNotAssignableError.
    """

    @computations.federated_computation()
    def initialize_fn():
      return intrinsics.federated_value(0, placements.SERVER)

    @computations.federated_computation(initialize_fn.type_signature.result)
    def next_fn(state):
      return state

    clients_int_type = computation_types.FederatedType(tf.int32,
                                                       placements.CLIENTS)

    @computations.federated_computation(clients_int_type)
    def report_fn(state):
      return state

    with self.assertRaises(errors.TemplateStateNotAssignableError):
      estimation_process.EstimationProcess(initialize_fn, next_fn, report_fn)
Example #6
0
def _constant_process(value):
    """Builds an `EstimationProcess` whose report is always `value` at SERVER.

    The state is an empty tuple placed at SERVER; `next_fn` returns it
    unchanged and ignores the client values (typed `NORM_TF_TYPE` at CLIENTS).
    """

    @computations.federated_computation
    def init_fn():
        return intrinsics.federated_value((), placements.SERVER)

    state_type = init_fn.type_signature.result

    @computations.federated_computation(
        state_type, computation_types.at_clients(NORM_TF_TYPE))
    def next_fn(state, value):
        # This `value` is the client-side argument; it shadows the outer one
        # and is intentionally unused.
        return state

    @computations.federated_computation(state_type)
    def report_fn(state):
        # Reports the constant captured from the enclosing call.
        return intrinsics.federated_value(value, placements.SERVER)

    return estimation_process.EstimationProcess(init_fn, next_fn, report_fn)
  def test_mapped_process_as_expected(self):
    """`map` keeps initialize/next and only changes the report result type."""
    original = estimation_process.EstimationProcess(
        test_initialize_fn, test_next_fn, test_report_fn)
    mapped = original.map(test_map_fn)

    self.assertIsInstance(mapped, estimation_process.EstimationProcess)
    # initialize and next are carried over as-is.
    self.assertEqual(original.initialize, mapped.initialize)
    self.assertEqual(original.next, mapped.next)
    # report keeps its parameter type but now yields test_map_fn's result.
    original_report_sig = original.report.type_signature
    mapped_report_sig = mapped.report.type_signature
    self.assertEqual(original_report_sig.parameter,
                     mapped_report_sig.parameter)
    self.assertEqual(test_map_fn.type_signature.result,
                     mapped_report_sig.result)
  def test_construction_with_unknown_dimension_does_not_raise(self):
    """Construction accepts parameter types with a statically unknown shape.

    `initialize_fn` yields an empty string vector, while `next_fn` and
    `report_fn` accept string vectors of shape `[None]`.
    """
    initialize_fn = computations.tf_computation()(
        lambda: tf.constant([], dtype=tf.string))

    @computations.tf_computation(
        computation_types.TensorType(shape=[None], dtype=tf.string))
    def next_fn(strings):
      # Appends one element per call, so the vector length is not fixed.
      return tf.concat([strings, tf.constant(['abc'])], axis=0)

    @computations.tf_computation(
        computation_types.TensorType(shape=[None], dtype=tf.string))
    def report_fn(strings):
      return strings

    try:
      estimation_process.EstimationProcess(initialize_fn, next_fn, report_fn)
    except Exception as e:  # pylint: disable=broad-except
      # Catch only Exception subclasses so KeyboardInterrupt/SystemExit still
      # propagate, and surface the underlying error in the failure message.
      self.fail('Could not construct an EstimationProcess with parameter types '
                'with statically unknown shape: {}'.format(e))
  def test_federated_mapped_process_as_expected(self):
    """`map` on a federated process only changes the report result type."""

    @computations.federated_computation()
    def initialize_fn():
      return intrinsics.federated_value(0, placements.SERVER)

    state_type = initialize_fn.type_signature.result

    @computations.federated_computation(state_type)
    def next_fn(state):
      return state

    @computations.federated_computation(state_type)
    def report_fn(state):
      return intrinsics.federated_map(test_report_fn, state)

    process = estimation_process.EstimationProcess(initialize_fn, next_fn,
                                                   report_fn)

    @computations.federated_computation(report_fn.type_signature.result)
    def map_fn(estimate):
      return intrinsics.federated_map(test_map_fn, estimate)

    mapped_process = process.map(map_fn)

    self.assertIsInstance(mapped_process, estimation_process.EstimationProcess)
    for attr_name in ('initialize', 'next'):
      self.assertEqual(getattr(process, attr_name),
                       getattr(mapped_process, attr_name))
    report_sig = process.report.type_signature
    mapped_report_sig = mapped_process.report.type_signature
    self.assertEqual(report_sig.parameter, mapped_report_sig.parameter)
    self.assertEqual(map_fn.type_signature.result, mapped_report_sig.result)
 def test_construction_does_not_raise(self):
     """Constructing from the shared test computations must succeed."""
     try:
         estimation_process.EstimationProcess(test_initialize_fn,
                                              test_next_fn, test_report_fn)
     except Exception as e:  # pylint: disable=broad-except
         # Catch only Exception subclasses so KeyboardInterrupt/SystemExit
         # still propagate, and surface the underlying error in the message.
         self.fail(
             'Could not construct a valid EstimationProcess: {}'.format(e))
 def test_report_state_not_assignable(self):
     """A `report_fn` taking float32 is rejected for the process state.

     The test expects the state from `test_initialize_fn`/`test_next_fn`
     (presumably int32 -- confirm) not to be assignable to float32.
     """

     @computations.tf_computation(tf.float32)
     def float_report_fn(estimate):
         return estimate

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         estimation_process.EstimationProcess(
             test_initialize_fn, test_next_fn, float_report_fn)
Example #12
0
    def test_increasing_zero_clip_sum(self):
        # Tests when zeroing and clipping are performed with non-integer clips.
        # Zeroing norm grows by 0.75 each time, clipping norm grows by 0.25.

        def _norm_next_fn_growing_by(increment):
            """Returns a federated next_fn adding `increment` to the norm."""

            @computations.federated_computation(_float_at_server,
                                                _float_at_clients)
            def next_fn(state, value):
                del value  # The norm schedule ignores client values.
                return intrinsics.federated_map(
                    computations.tf_computation(lambda x: x + increment,
                                                tf.float32),
                    state)

            return next_fn

        zeroing_norm_process = estimation_process.EstimationProcess(
            _test_init_fn, _norm_next_fn_growing_by(0.75), _test_report_fn)
        clipping_norm_process = estimation_process.EstimationProcess(
            _test_init_fn, _norm_next_fn_growing_by(0.25), _test_report_fn)

        factory = robust.zeroing_factory(zeroing_norm_process,
                                         _clipped_sum(clipping_norm_process))
        process = factory.create(computation_types.to_type(tf.float32))

        client_data = [1.0, 2.0, 3.0]
        # One row per round: (zeroing_norm, clipping_norm, zeroed_count,
        # clipped_count, result). The numbers are consistent with values
        # above the zeroing norm being zeroed, and remaining values above the
        # clipping norm being clipped to it before summation.
        expected_rounds = [
            (1.0, 1.0, 2, 0, 1.0),
            (1.75, 1.25, 2, 0, 1.0),
            (2.5, 1.5, 1, 1, 2.5),
            (3.25, 1.75, 0, 2, 4.5),
            (4.0, 2.0, 0, 1, 5.0),
        ]

        state = process.initialize()
        for (zeroing_norm, clipping_norm, zeroed_count, clipped_count,
             result) in expected_rounds:
            output = process.next(state, client_data)
            measurements = output.measurements
            self.assertAllClose(zeroing_norm, measurements['zeroing_norm'])
            self.assertAllClose(clipping_norm,
                                measurements['zeroing']['clipping_norm'])
            self.assertEqual(zeroed_count, measurements['zeroed_count'])
            self.assertEqual(clipped_count,
                             measurements['zeroing']['clipped_count'])
            self.assertAllClose(result, output.result)
            state = output.state
Example #13
0
 def test_next_state_not_assignable(self):
     """A `next_fn` taking float32 state is rejected.

     The test expects the state from `test_initialize_fn` (presumably int32 --
     confirm) not to be assignable to float32.
     """

     @tensorflow_computation.tf_computation(tf.float32)
     def float_next_fn(state):
         return tf.cast(state, tf.float32)

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         estimation_process.EstimationProcess(
             test_initialize_fn, float_next_fn, test_report_fn)
Example #14
0
 def test_init_state_not_assignable(self):
     """An `initialize_fn` whose result does not fit `next_fn` is rejected.

     `float_initialize_fn` returns float32, while `test_next_fn` is expected to
     take a non-float state (int32 in the shared helpers -- confirm), so
     construction must raise TemplateStateNotAssignableError.
     """
     float_initialize_fn = computations.tf_computation()(lambda: 0.0)
     # Use the real next/report computations so that the ONLY mismatch is
     # between initialize_fn's result and next_fn's state parameter. Pairing
     # the float init with a float32 next_fn (and passing test_next_fn as the
     # report_fn) would make the error come from the report_fn check instead.
     with self.assertRaises(errors.TemplateStateNotAssignableError):
         estimation_process.EstimationProcess(float_initialize_fn,
                                              test_next_fn, test_report_fn)
Example #15
0
@computations.tf_computation()
def test_initialize_fn():
    """Returns the initial state: an int32 scalar 0.

    Wrapped as a TFF computation because `EstimationProcess` requires its
    `initialize_fn` to be a Computation; as a plain Python function, the
    module-level `EstimationProcess(...)` construction would raise TypeError.
    """
    return tf.constant(0, tf.int32)


@computations.tf_computation(tf.int32)
def test_next_fn(state):
    """Advances the int32 state by one."""
    next_state = state + 1
    return next_state


@computations.tf_computation(tf.int32)
def test_get_estimate_fn(state):
    """Reports half of the int32 state as a float32 estimate."""
    state_as_float = tf.cast(state, tf.float32)
    return state_as_float / 2.0


# Module-level process shared by the tests: state built by
# `test_initialize_fn`, advanced by `test_next_fn` (adds one each round), and
# reported via `test_get_estimate_fn` (state / 2 as float32).
test_estimation_process = estimation_process.EstimationProcess(
    initialize_fn=test_initialize_fn,
    next_fn=test_next_fn,
    get_estimate_fn=test_get_estimate_fn)


@computations.tf_computation(tf.float32)
def test_transform_fn(arg):
    """Applies the affine map `3 * arg + 1` to a float32 value."""
    scaled = 3.0 * arg
    return scaled + 1.0


class EstimationProcessTest(test_case.TestCase):
    """Tests for `estimation_process.EstimationProcess`."""

    def test_get_estimate_not_tff_computation_raises(self):
        # Expects a TypeError matching 'Expected ...Computation' from the
        # constructor.
        # NOTE(review): this snippet is truncated here; the remaining argument
        # (presumably a non-Computation `get_estimate_fn`) is not visible.
        with self.assertRaisesRegex(TypeError,
                                    r'Expected .*\.Computation, .*'):
            estimation_process.EstimationProcess(
                initialize_fn=test_initialize_fn,
                next_fn=test_next_fn,
Example #16
0
def _test_estimation_process(factor):
    """Builds an `EstimationProcess` from the float test helpers.

    `factor` is forwarded to the init and next helper factories; the report
    computation is the shared `_test_float_report_fn`.
    """
    init_fn = _test_float_init_fn(factor)
    next_fn = _test_float_next_fn(factor)
    return estimation_process.EstimationProcess(init_fn, next_fn,
                                                _test_float_report_fn)
def _test_norm_process(init_fn=_test_init_fn,
                       next_fn=_test_next_fn,
                       report_fn=_test_report_fn):
  """Builds an `EstimationProcess`, defaulting to the module's test fns."""
  norm_process = estimation_process.EstimationProcess(init_fn, next_fn,
                                                      report_fn)
  return norm_process
 def test_init_param_not_empty_raises(self):
     """An `initialize_fn` that takes a parameter is rejected."""

     @computations.tf_computation(tf.int32)
     def initialize_with_arg(x):
         return x

     with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
         estimation_process.EstimationProcess(
             initialize_with_arg, test_next_fn, test_report_fn)