def test_next_not_tff_computation_raises(self):
  """A plain Python callable passed as `next_fn` is rejected with TypeError."""

  def plain_python_next(state, w, u):
    # Deliberately not wrapped in a TFF computation decorator.
    return MeasuredProcessOutput(state, w + u, ())

  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    finalizers.FinalizerProcess(
        initialize_fn=test_initialize_fn, next_fn=plain_python_next)
def test_two_param_next_raises(self):
  """A `next_fn` taking only (state, weights) has the wrong arity."""

  @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE)
  def two_arg_next(state, weights):
    output = MeasuredProcessOutput(state, weights, server_zero())
    return output

  with self.assertRaises(errors.TemplateNextFnNumArgsError):
    finalizers.FinalizerProcess(test_initialize_fn, two_arg_next)
def test_next_return_tuple_raises(self):
  """Returning a bare tuple instead of MeasuredProcessOutput is rejected."""

  @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                      SERVER_FLOAT)
  def next_returning_tuple(state, weights, update):
    finalized = test_finalizer_result(weights, update)
    measurements = server_zero()
    return state, finalized, measurements

  with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
    finalizers.FinalizerProcess(test_initialize_fn, next_returning_tuple)
def test_non_server_placed_next_measurements_raises(self):
  """Measurements placed at CLIENTS rather than SERVER are rejected."""

  @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                      SERVER_FLOAT)
  def next_fn(state, weights, update):
    finalized = test_finalizer_result(weights, update)
    client_measurements = intrinsics.federated_value(1.0, placements.CLIENTS)
    return MeasuredProcessOutput(state, finalized, client_measurements)

  with self.assertRaises(errors.TemplatePlacementError):
    finalizers.FinalizerProcess(test_initialize_fn, next_fn)
def test_next_return_odict_raises(self):
  """An OrderedDict with the right field names is still not accepted."""

  @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                      SERVER_FLOAT)
  def odict_next_fn(state, weights, update):
    # Same keys and order as MeasuredProcessOutput, but the wrong container.
    return collections.OrderedDict([
        ('state', state),
        ('result', test_finalizer_result(weights, update)),
        ('measurements', server_zero()),
    ])

  with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
    finalizers.FinalizerProcess(test_initialize_fn, odict_next_fn)
def test_next_return_namedtuple_raises(self):
  """A look-alike namedtuple is rejected in place of MeasuredProcessOutput."""
  # Mimics MeasuredProcessOutput's name and fields without being that type.
  lookalike_output = collections.namedtuple(
      'MeasuredProcessOutput', ['state', 'result', 'measurements'])

  @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                      SERVER_FLOAT)
  def namedtuple_next_fn(state, weights, update):
    finalized = test_finalizer_result(weights, update)
    return lookalike_output(
        state=state, result=finalized, measurements=server_zero())

  with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
    finalizers.FinalizerProcess(test_initialize_fn, namedtuple_next_fn)
def test_next_state_not_assignable(self):
  """next_fn returns a float SERVER state while its input state is SERVER_INT."""

  @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                      SERVER_FLOAT)
  def float_next_fn(state, weights, update):
    del state  # The returned state is intentionally unrelated to the input.
    float_state = intrinsics.federated_value(0.0, placements.SERVER)
    finalized = test_finalizer_result(weights, update)
    measurements = intrinsics.federated_value(1, placements.SERVER)
    return MeasuredProcessOutput(float_state, finalized, measurements)

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    finalizers.FinalizerProcess(test_initialize_fn, float_next_fn)
def test_init_tuple_of_federated_types_raises(self):
  """An init returning a tuple of server values (not one federated value) fails."""
  initialize_fn = computations.federated_computation()(
      lambda: (server_zero(), server_zero()))
  tuple_state_type = initialize_fn.type_signature.result

  @computations.federated_computation(tuple_state_type, MODEL_WEIGHTS_TYPE,
                                      SERVER_FLOAT)
  def next_fn(state, weights, update):
    finalized = test_finalizer_result(weights, update)
    return MeasuredProcessOutput(state, finalized, server_zero())

  with self.assertRaises(errors.TemplateNotFederatedError):
    finalizers.FinalizerProcess(initialize_fn, next_fn)
def test_non_server_placed_next_weight_param_raises(self):
  """A weights parameter placed at CLIENTS is rejected."""
  clients_weights_type = computation_types.at_clients(
      MODEL_WEIGHTS_TYPE.member)

  @computations.federated_computation(SERVER_INT, clients_weights_type,
                                      SERVER_FLOAT)
  def next_fn(state, weights, update):
    # Sum to the server so the body itself type-checks; only the parameter
    # placement should trigger the error.
    server_weights = intrinsics.federated_sum(weights)
    finalized = test_finalizer_result(server_weights, update)
    return MeasuredProcessOutput(state, finalized, server_zero())

  with self.assertRaises(errors.TemplatePlacementError):
    finalizers.FinalizerProcess(test_initialize_fn, next_fn)
def test_non_federated_init_next_raises(self):
  """Non-federated (tf_computation) init and next functions are rejected."""
  initialize_fn = computations.tf_computation(lambda: 0)
  unplaced_weights_type = computation_types.to_type(
      model_utils.ModelWeights(tf.float32, ()))

  @computations.tf_computation(tf.int32, unplaced_weights_type, tf.float32)
  def next_fn(state, weights, update):
    updated_weights = model_utils.ModelWeights(weights.trainable + update, ())
    return MeasuredProcessOutput(state, updated_weights, 0)

  with self.assertRaises(errors.TemplateNotFederatedError):
    finalizers.FinalizerProcess(initialize_fn, next_fn)
def test_finalizer():
  """Builds a FinalizerProcess whose next_fn adds `updates` to trainable weights.

  The state is passed through unchanged and the measurements are
  `empty_at_server()`.

  NOTE(review): despite the `test_` prefix this is a helper that returns a
  process, not a test; pytest-style collectors may pick it up — confirm.

  Returns:
    A `finalizers.FinalizerProcess` built from `empty_init_fn`.
  """
  state_type = empty_init_fn.type_signature.result
  weights_type = computation_types.at_server(MODEL_WEIGHTS_TYPE)
  update_type = computation_types.at_server(FLOAT_TYPE)

  @computations.federated_computation(state_type, weights_type, update_type)
  def next_fn(state, weights, updates):
    add_fn = computations.tf_computation(lambda x, y: x + y)
    summed_trainable = intrinsics.federated_map(
        add_fn, (weights.trainable, updates))
    new_weights = intrinsics.federated_zip(
        model_utils.ModelWeights(summed_trainable, ()))
    return measured_process.MeasuredProcessOutput(state, new_weights,
                                                  empty_at_server())

  return finalizers.FinalizerProcess(empty_init_fn, next_fn)
def test_result_not_assignable_to_weight_raises(self):
  """A result cast to float64 no longer matches the weights type and fails."""
  bad_cast_fn = computations.tf_computation(
      lambda x: tf.nest.map_structure(lambda y: tf.cast(y, tf.float64), x))

  @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                      SERVER_FLOAT)
  def next_fn(state, weights, update):
    finalized = test_finalizer_result(weights, update)
    recast = intrinsics.federated_map(bad_cast_fn, finalized)
    return MeasuredProcessOutput(state, recast, server_zero())

  with self.assertRaises(finalizers.FinalizerResultTypeError):
    finalizers.FinalizerProcess(test_initialize_fn, next_fn)
def test_construction_with_empty_state_does_not_raise(self):
  """A FinalizerProcess whose state is an empty SERVER-placed struct builds."""
  initialize_fn = computations.federated_computation()(
      lambda: intrinsics.federated_value((), placements.SERVER))

  @computations.federated_computation(
      initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, SERVER_FLOAT)
  def next_fn(state, weights, update):
    return MeasuredProcessOutput(
        state, test_finalizer_result(weights, update),
        intrinsics.federated_value(1, placements.SERVER))

  try:
    finalizers.FinalizerProcess(initialize_fn, next_fn)
  except Exception as e:  # pylint: disable=broad-except
    # Catch `Exception` rather than a bare `except:` so KeyboardInterrupt /
    # SystemExit propagate, and include the cause in the failure message.
    self.fail(
        'Could not construct a FinalizerProcess with empty state: {!r}'.format(
            e))
def test_bad_next_weights_param_type_raises(self):
  """A weights parameter typed as an OrderedDict (not ModelWeights) fails."""
  odict_weights_type = computation_types.to_type(
      collections.OrderedDict(trainable=tf.float32, non_trainable=()))
  bad_model_weights_type = computation_types.at_server(odict_weights_type)

  @computations.federated_computation(SERVER_INT, bad_model_weights_type,
                                      SERVER_FLOAT)
  def next_fn(state, weights, update):
    new_trainable = federated_add(weights['trainable'], update)
    zipped_weights = intrinsics.federated_zip(
        model_utils.ModelWeights(new_trainable, ()))
    return MeasuredProcessOutput(state, zipped_weights, server_zero())

  with self.assertRaises(finalizers.ModelWeightsTypeError):
    finalizers.FinalizerProcess(test_initialize_fn, next_fn)
def test_init_param_not_empty_raises(self):
  """An initialize_fn that takes a parameter is rejected."""
  identity_init = computations.federated_computation(SERVER_INT)(lambda x: x)

  with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
    finalizers.FinalizerProcess(identity_init, test_next_fn)
def test_init_state_not_assignable(self):
  """An init producing a float state is incompatible with `test_next_fn`."""
  make_float_state = lambda: intrinsics.federated_value(0.0, placements.SERVER)
  float_initialize_fn = computations.federated_computation()(make_float_state)

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    finalizers.FinalizerProcess(float_initialize_fn, test_next_fn)
def test_init_not_tff_computation_raises(self):
  """A plain Python callable passed as `initialize_fn` raises TypeError."""

  def plain_python_init():
    return 0

  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    finalizers.FinalizerProcess(
        initialize_fn=plain_python_init, next_fn=test_next_fn)
def test_construction_does_not_raise(self):
  """The module's reference init/next pair builds a FinalizerProcess."""
  try:
    finalizers.FinalizerProcess(test_initialize_fn, test_next_fn)
  except Exception as e:  # pylint: disable=broad-except
    # Catch `Exception` rather than a bare `except:` so KeyboardInterrupt /
    # SystemExit propagate, and include the cause in the failure message.
    self.fail('Could not construct a valid FinalizerProcess: {!r}'.format(e))