def test_next_state_not_assignable_tuple_result(self):
    """Expects a state-assignability error for a tuple-returning `next_fn`.

    The `next_fn` under test takes two float32 arguments and returns a
    `(state, x)` tuple whose first element is cast to float32.
    """
    tuple_next_fn = computations.tf_computation(tf.float32, tf.float32)(
        lambda state, x: (tf.cast(state, tf.float32), x))

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        estimation_process.EstimationProcess(
            test_initialize_fn, tuple_next_fn, test_report_fn)
def test_next_not_tff_computation_raises(self):
    """Expects `TypeError` when `next_fn` is a plain Python lambda."""
    constructor_kwargs = dict(
        initialize_fn=test_initialize_fn,
        # Deliberately an unwrapped Python callable, not a computation.
        next_fn=lambda state: state,
        report_fn=test_report_fn,
    )

    with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
        estimation_process.EstimationProcess(**constructor_kwargs)
def test_map_estimate_not_assignable(self):
    """Expects `EstimateNotAssignableError` when mapping with an int32 fn."""
    int_map_fn = computations.tf_computation(tf.int32)(
        lambda estimate: estimate)
    base_process = estimation_process.EstimationProcess(
        test_initialize_fn, test_next_fn, test_report_fn)

    with self.assertRaises(estimation_process.EstimateNotAssignableError):
        base_process.map(int_map_fn)
def test_construction_with_empty_state_does_not_raise(self):
    """Checks an `EstimationProcess` can be built with an empty-tuple state."""
    initialize_fn = computations.tf_computation()(lambda: ())
    next_fn = computations.tf_computation(())(lambda x: (x, 1.0))
    report_fn = computations.tf_computation(())(lambda x: x)
    try:
        estimation_process.EstimationProcess(initialize_fn, next_fn, report_fn)
    except Exception:  # pylint: disable=broad-except
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit propagate instead of being reported
        # as a construction failure.
        self.fail(
            'Could not construct an EstimationProcess with empty state.')
def test_federated_report_state_not_assignable(self):
    """Expects an error when `report_fn` takes a CLIENTS-placed int32.

    The initial state is built with `federated_value(0, placements.SERVER)`,
    while the report function is declared over a CLIENTS-placed int32.
    """
    initialize_fn = computations.federated_computation()(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    server_state_type = initialize_fn.type_signature.result
    next_fn = computations.federated_computation(server_state_type)(
        lambda state: state)

    clients_int_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS)
    report_fn = computations.federated_computation(clients_int_type)(
        lambda state: state)

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        estimation_process.EstimationProcess(
            initialize_fn, next_fn, report_fn)
def _constant_process(value):
    """Creates an `EstimationProcess` that reports a constant value.

    Args:
      value: The value that the process's report function places at the
        server via `intrinsics.federated_value`.

    Returns:
      An `estimation_process.EstimationProcess` whose state is an empty tuple
      placed at the server and whose `next` returns the state unchanged.
    """
    init_fn = computations.federated_computation(
        lambda: intrinsics.federated_value((), placements.SERVER))
    server_state_type = init_fn.type_signature.result
    client_value_type = computation_types.at_clients(NORM_TF_TYPE)

    # The lambda parameter names are kept as-is; the wrapper may use them to
    # name the computation's parameters.
    next_fn = computations.federated_computation(
        lambda state, value: state, server_state_type, client_value_type)
    report_fn = computations.federated_computation(
        lambda state: intrinsics.federated_value(value, placements.SERVER),
        server_state_type)

    return estimation_process.EstimationProcess(init_fn, next_fn, report_fn)
def test_mapped_process_as_expected(self):
    """Checks which parts of the process `map` keeps and which it replaces.

    `initialize`, `next`, and the report parameter type must equal those of
    the original process; the report result type must equal the result type
    of `test_map_fn`.
    """
    original = estimation_process.EstimationProcess(
        test_initialize_fn, test_next_fn, test_report_fn)
    mapped = original.map(test_map_fn)

    self.assertIsInstance(mapped, estimation_process.EstimationProcess)
    self.assertEqual(original.initialize, mapped.initialize)
    self.assertEqual(original.next, mapped.next)

    original_report_type = original.report.type_signature
    self.assertEqual(original_report_type.parameter,
                     mapped.report.type_signature.parameter)
    self.assertEqual(test_map_fn.type_signature.result,
                     mapped.report.type_signature.result)
def test_construction_with_unknown_dimension_does_not_raise(self):
    """Checks construction works with `[None]`-shaped string parameters.

    The initial state is an empty string tensor; `next_fn` and `report_fn`
    both declare a string tensor parameter of shape `[None]`.
    """
    initialize_fn = computations.tf_computation()(
        lambda: tf.constant([], dtype=tf.string))

    @computations.tf_computation(
        computation_types.TensorType(shape=[None], dtype=tf.string))
    def next_fn(strings):
        return tf.concat([strings, tf.constant(['abc'])], axis=0)

    @computations.tf_computation(
        computation_types.TensorType(shape=[None], dtype=tf.string))
    def report_fn(strings):
        return strings

    try:
        estimation_process.EstimationProcess(initialize_fn, next_fn, report_fn)
    except Exception:  # pylint: disable=broad-except
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit propagate instead of being reported
        # as a construction failure.
        self.fail('Could not construct an EstimationProcess with parameter types '
                  'with statically unknown shape.')
def test_federated_mapped_process_as_expected(self):
    """Checks `map` on a process built from federated computations.

    `initialize`, `next`, and the report parameter type must equal those of
    the original process; the report result type must equal the result type
    of the federated map function.
    """
    initialize_fn = computations.federated_computation()(
        lambda: intrinsics.federated_value(0, placements.SERVER))
    server_state_type = initialize_fn.type_signature.result
    next_fn = computations.federated_computation(server_state_type)(
        lambda state: state)
    report_fn = computations.federated_computation(server_state_type)(
        lambda state: intrinsics.federated_map(test_report_fn, state))
    original = estimation_process.EstimationProcess(
        initialize_fn, next_fn, report_fn)

    estimate_type = report_fn.type_signature.result
    map_fn = computations.federated_computation(estimate_type)(
        lambda estimate: intrinsics.federated_map(test_map_fn, estimate))
    mapped = original.map(map_fn)

    self.assertIsInstance(mapped, estimation_process.EstimationProcess)
    self.assertEqual(original.initialize, mapped.initialize)
    self.assertEqual(original.next, mapped.next)
    self.assertEqual(original.report.type_signature.parameter,
                     mapped.report.type_signature.parameter)
    self.assertEqual(map_fn.type_signature.result,
                     mapped.report.type_signature.result)
def test_construction_does_not_raise(self):
    """Checks the shared test fixtures form a constructible process."""
    try:
        estimation_process.EstimationProcess(
            test_initialize_fn, test_next_fn, test_report_fn)
    except Exception:  # pylint: disable=broad-except
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit propagate instead of being reported
        # as a construction failure.
        self.fail('Could not construct a valid EstimationProcess.')
def test_report_state_not_assignable(self):
    """Expects a state-assignability error for a float32 `report_fn`."""
    float_report_fn = computations.tf_computation(tf.float32)(
        lambda estimate: estimate)

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        estimation_process.EstimationProcess(
            test_initialize_fn, test_next_fn, float_report_fn)
def test_increasing_zero_clip_sum(self):
    """Runs five rounds and checks norms, counts, and the result each round."""
    # Tests when zeroing and clipping are performed with non-integer clips.
    # Zeroing norm grows by 0.75 each time, clipping norm grows by 0.25.

    @computations.federated_computation(_float_at_server, _float_at_clients)
    def zeroing_next_fn(state, value):
        del value
        return intrinsics.federated_map(
            computations.tf_computation(lambda x: x + 0.75, tf.float32), state)

    @computations.federated_computation(_float_at_server, _float_at_clients)
    def clipping_next_fn(state, value):
        del value
        return intrinsics.federated_map(
            computations.tf_computation(lambda x: x + 0.25, tf.float32), state)

    zeroing_norm_process = estimation_process.EstimationProcess(
        _test_init_fn, zeroing_next_fn, _test_report_fn)
    clipping_norm_process = estimation_process.EstimationProcess(
        _test_init_fn, clipping_next_fn, _test_report_fn)

    factory = robust.zeroing_factory(zeroing_norm_process,
                                     _clipped_sum(clipping_norm_process))
    process = factory.create(computation_types.to_type(tf.float32))

    client_data = [1.0, 2.0, 3.0]

    # One row per round, in execution order:
    # (zeroing_norm, clipping_norm, zeroed_count, clipped_count, result)
    expected_rounds = (
        (1.0, 1.0, 2, 0, 1.0),
        (1.75, 1.25, 2, 0, 1.0),
        (2.5, 1.5, 1, 1, 2.5),
        (3.25, 1.75, 0, 2, 4.5),
        (4.0, 2.0, 0, 1, 5.0),
    )

    state = process.initialize()
    for (zeroing_norm, clipping_norm, zeroed_count, clipped_count,
         result) in expected_rounds:
        output = process.next(state, client_data)
        measurements = output.measurements
        self.assertAllClose(zeroing_norm, measurements['zeroing_norm'])
        self.assertAllClose(clipping_norm,
                            measurements['zeroing']['clipping_norm'])
        self.assertEqual(zeroed_count, measurements['zeroed_count'])
        self.assertEqual(clipped_count,
                         measurements['zeroing']['clipped_count'])
        self.assertAllClose(result, output.result)
        # Feed this round's state into the next round.
        state = output.state
def test_next_state_not_assignable(self):
    """Expects a state-assignability error for a float32 `next_fn`."""
    cast_to_float_next = tensorflow_computation.tf_computation(tf.float32)(
        lambda state: tf.cast(state, tf.float32))

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        estimation_process.EstimationProcess(
            test_initialize_fn, cast_to_float_next, test_report_fn)
def test_init_state_not_assignable(self):
    """Expects a state-assignability error with a float initial state."""
    float_initialize_fn = computations.tf_computation()(lambda: 0.0)
    float_next_fn = computations.tf_computation(tf.float32)(lambda x: x)

    # NOTE(review): `test_next_fn` is passed as the third argument, where
    # other tests in this file pass a report function. Confirm this is
    # intentional and not a typo for `test_report_fn`.
    constructor_args = (float_initialize_fn, float_next_fn, test_next_fn)

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        estimation_process.EstimationProcess(*constructor_args)
def test_initialize_fn(): return tf.constant(0, tf.int32) @computations.tf_computation(tf.int32) def test_next_fn(state): return state + 1 @computations.tf_computation(tf.int32) def test_get_estimate_fn(state): return tf.cast(state, tf.float32) / 2.0 test_estimation_process = estimation_process.EstimationProcess( initialize_fn=test_initialize_fn, next_fn=test_next_fn, get_estimate_fn=test_get_estimate_fn) @computations.tf_computation(tf.float32) def test_transform_fn(arg): return 3.0 * arg + 1.0 class EstimationProcessTest(test_case.TestCase): def test_get_estimate_not_tff_computation_raises(self): with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'): estimation_process.EstimationProcess( initialize_fn=test_initialize_fn, next_fn=test_next_fn,
def _test_estimation_process(factor):
    """Builds an `EstimationProcess` from the float test helpers.

    Args:
      factor: Forwarded to `_test_float_init_fn` and `_test_float_next_fn`.

    Returns:
      An `estimation_process.EstimationProcess` using
      `_test_float_report_fn` as its report function.
    """
    init_fn = _test_float_init_fn(factor)
    next_fn = _test_float_next_fn(factor)
    return estimation_process.EstimationProcess(
        init_fn, next_fn, _test_float_report_fn)
def _test_norm_process(init_fn=_test_init_fn,
                       next_fn=_test_next_fn,
                       report_fn=_test_report_fn):
    """Builds an `EstimationProcess`, defaulting to the module test fixtures.

    Args:
      init_fn: Initialize function; defaults to `_test_init_fn`.
      next_fn: Next function; defaults to `_test_next_fn`.
      report_fn: Report function; defaults to `_test_report_fn`.

    Returns:
      The constructed `estimation_process.EstimationProcess`.
    """
    norm_process = estimation_process.EstimationProcess(
        init_fn, next_fn, report_fn)
    return norm_process
def test_init_param_not_empty_raises(self):
    """Expects an error when `initialize_fn` declares an int32 parameter."""
    initialize_with_param = computations.tf_computation(tf.int32)(
        lambda x: x)

    with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
        estimation_process.EstimationProcess(
            initialize_with_param, test_next_fn, test_report_fn)