def test_next_not_tff_computation_raises(self):
    """Passing an undecorated Python callable as `next_fn` raises TypeError."""

    # Deliberately left as a plain Python function (no computation decorator).
    def plain_python_next_fn(state, w, d):
        return MeasuredProcessOutput(state, w + d, ())

    expected_message = r'Expected .*\.Computation, .*'
    with self.assertRaisesRegex(TypeError, expected_message):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn,
            next_fn=plain_python_next_fn)
def test_next_return_tuple_raises(self):
    """A `next_fn` returning a bare tuple is not accepted.

    The computation returns a plain 3-tuple instead of a
    `MeasuredProcessOutput`; construction is expected to raise
    `errors.TemplateNotMeasuredProcessOutputError`.
    """

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def tuple_next_fn(state, weights, data):
        client_result = test_client_result(weights, data)
        measurements = server_zero()
        return (state, client_result, measurements)

    expected_error = errors.TemplateNotMeasuredProcessOutputError
    with self.assertRaises(expected_error):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=tuple_next_fn)
def test_two_param_next_raises(self):
    """A `next_fn` declared with only two parameters is rejected.

    Construction is expected to raise `errors.TemplateNextFnNumArgsError`.
    """

    @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE)
    def next_fn(state, weights):
        trainable = weights.trainable
        measurements = server_zero()
        return MeasuredProcessOutput(state, trainable, measurements)

    with self.assertRaises(errors.TemplateNextFnNumArgsError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_non_server_placed_next_measurements_raises(self):
    """Measurements placed at CLIENTS must raise `TemplatePlacementError`."""

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        client_result = test_client_result(weights, data)
        # The measurements value is intentionally created at CLIENTS.
        clients_measurements = intrinsics.federated_value(
            1.0, placements.CLIENTS)
        return MeasuredProcessOutput(
            state, client_result, clients_measurements)

    with self.assertRaises(errors.TemplatePlacementError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_next_state_not_assignable(self):
    """`next_fn` returning a float state for a `SERVER_INT` state parameter.

    The incoming state is discarded and a server-placed `0.0` is returned in
    its place; construction is expected to raise
    `errors.TemplateStateNotAssignableError`.
    """

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def float_next_fn(state, weights, data):
        del state
        float_state = intrinsics.federated_value(0.0, placements.SERVER)
        client_result = test_client_result(weights, data)
        measurements = intrinsics.federated_value(1, placements.SERVER)
        return MeasuredProcessOutput(float_state, client_result, measurements)

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=float_next_fn)
def test_non_sequence_next_data_param_raises(self):
    """Data parameter typed `CLIENTS_FLOAT` raises `ClientDataTypeError`.

    The third parameter uses `CLIENTS_FLOAT` where the other tests in this
    file use `CLIENTS_FLOAT_SEQUENCE`.
    """

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT)
    def next_fn(state, weights, data):
        update = federated_add(weights.trainable, data)
        update_weight = client_one()
        zipped_result = intrinsics.federated_zip(
            client_works.ClientResult(update, update_weight))
        measurements = server_zero()
        return MeasuredProcessOutput(state, zipped_result, measurements)

    with self.assertRaises(client_works.ClientDataTypeError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_non_zipped_next_result_raises(self):
    """A `ClientResult` that is not passed through `federated_zip` is rejected.

    Construction is expected to raise `errors.TemplatePlacementError`.
    """

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        reduced_data = intrinsics.federated_map(tf_data_sum, data)
        update = federated_add(weights.trainable, reduced_data)
        update_weight = client_one()
        # Returned as-is: no intrinsics.federated_zip around the container.
        unzipped_result = client_works.ClientResult(update, update_weight)
        measurements = server_zero()
        return MeasuredProcessOutput(state, unzipped_result, measurements)

    with self.assertRaises(errors.TemplatePlacementError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_non_clients_placed_next_weights_param_raises(self):
    """Server-placed weights parameter must raise `TemplatePlacementError`."""
    # Same member type as MODEL_WEIGHTS_TYPE, but re-placed at the server.
    server_weights_type = computation_types.at_server(
        MODEL_WEIGHTS_TYPE.member)

    @computations.federated_computation(
        SERVER_INT, server_weights_type, CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        broadcast_weights = intrinsics.federated_broadcast(weights)
        client_result = test_client_result(broadcast_weights, data)
        measurements = server_zero()
        return MeasuredProcessOutput(state, client_result, measurements)

    with self.assertRaises(errors.TemplatePlacementError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_non_server_placed_init_state_raises(self):
    """Initial state placed at CLIENTS must raise `TemplatePlacementError`."""

    @computations.federated_computation
    def initialize_fn():
        return intrinsics.federated_value(0, placements.CLIENTS)

    # Reuse the initializer's result type so that state types line up and
    # only the placement differs from the other tests.
    state_type = initialize_fn.type_signature.result

    @computations.federated_computation(
        state_type, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        client_result = test_client_result(weights, data)
        measurements = server_zero()
        return MeasuredProcessOutput(state, client_result, measurements)

    with self.assertRaises(errors.TemplatePlacementError):
        client_works.ClientWorkProcess(
            initialize_fn=initialize_fn, next_fn=next_fn)
def test_init_tuple_of_federated_types_raises(self):
    """Initial state that is a tuple of two `server_zero()` values is rejected.

    Construction is expected to raise `errors.TemplateNotFederatedError`.
    """

    @computations.federated_computation()
    def initialize_fn():
        first = server_zero()
        second = server_zero()
        return (first, second)

    state_type = initialize_fn.type_signature.result

    @computations.federated_computation(
        state_type, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        client_result = test_client_result(weights, data)
        measurements = server_zero()
        return MeasuredProcessOutput(state, client_result, measurements)

    with self.assertRaises(errors.TemplateNotFederatedError):
        client_works.ClientWorkProcess(
            initialize_fn=initialize_fn, next_fn=next_fn)
def test_non_federated_init_next_raises(self):
    """`tf_computation`-built init/next functions are rejected.

    Both functions are created with `computations.tf_computation` rather than
    `federated_computation`; construction is expected to raise
    `errors.TemplateNotFederatedError`.
    """

    @computations.tf_computation
    def initialize_fn():
        return 0

    data_type = computation_types.SequenceType(tf.float32)

    @computations.tf_computation(
        tf.int32, MODEL_WEIGHTS_TYPE.member, data_type)
    def next_fn(state, weights, data):
        update = weights.trainable + tf_data_sum(data)
        client_result = client_works.ClientResult(update, ())
        return MeasuredProcessOutput(state, client_result, ())

    with self.assertRaises(errors.TemplateNotFederatedError):
        client_works.ClientWorkProcess(
            initialize_fn=initialize_fn, next_fn=next_fn)
def test_next_return_namedtuple_raises(self):
    """A look-alike namedtuple is not accepted as the `next_fn` output.

    The namedtuple has the type name 'MeasuredProcessOutput' and the fields
    `state`, `result`, `measurements`, but construction is still expected to
    raise `errors.TemplateNotMeasuredProcessOutputError`.
    """
    field_names = ['state', 'result', 'measurements']
    lookalike_output = collections.namedtuple(
        'MeasuredProcessOutput', field_names)

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def namedtuple_next_fn(state, weights, data):
        client_result = test_client_result(weights, data)
        measurements = server_zero()
        return lookalike_output(state, client_result, measurements)

    expected_error = errors.TemplateNotMeasuredProcessOutputError
    with self.assertRaises(expected_error):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=namedtuple_next_fn)
def test_incorrect_client_result_container_raises(self):
    """An `OrderedDict` result in place of `ClientResult` is rejected.

    The dict carries the keys `update` and `update_weight`, but construction
    is still expected to raise `client_works.ClientResultTypeError`.
    """

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        reduced_data = intrinsics.federated_map(tf_data_sum, data)
        # Same field names, wrong container type.
        result_fields = collections.OrderedDict([
            ('update', federated_add(weights.trainable, reduced_data)),
            ('update_weight', client_one()),
        ])
        bad_client_result = intrinsics.federated_zip(result_fields)
        measurements = server_zero()
        return MeasuredProcessOutput(state, bad_client_result, measurements)

    with self.assertRaises(client_works.ClientResultTypeError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_non_clients_placed_next_data_param_raises(self):
    """Server-placed data parameter must raise `TemplatePlacementError`."""
    float_sequence = computation_types.SequenceType(tf.float32)
    server_placed_data_type = computation_types.at_server(float_sequence)

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, server_placed_data_type)
    def next_fn(state, weights, data):
        broadcast_data = intrinsics.federated_broadcast(data)
        client_result = test_client_result(weights, broadcast_data)
        measurements = server_zero()
        return MeasuredProcessOutput(state, client_result, measurements)

    with self.assertRaises(errors.TemplatePlacementError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_bad_next_weights_param_type_raises(self):
    """An `OrderedDict`-based weights type raises `ModelWeightsTypeError`.

    The weights parameter is a clients-placed struct built from an
    `OrderedDict` with `trainable` and `non_trainable` entries.
    """
    weights_struct = collections.OrderedDict([
        ('trainable', tf.float32),
        ('non_trainable', ()),
    ])
    bad_model_weights_type = computation_types.at_clients(
        computation_types.to_type(weights_struct))

    @computations.federated_computation(
        SERVER_INT, bad_model_weights_type, CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        client_result = test_client_result(weights, data)
        measurements = server_zero()
        return MeasuredProcessOutput(state, client_result, measurements)

    with self.assertRaises(client_works.ModelWeightsTypeError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_client_work():
    """Builds a `ClientWorkProcess` from `empty_init_fn` and a mapped `next_fn`.

    The `next_fn` maps `make_result` over the (weights, client data) pair and
    reports `empty_at_server()` as its measurements.

    Returns:
      The constructed `client_works.ClientWorkProcess`.
    """

    @computations.tf_computation()
    def make_result(value, data):
        update = value.trainable
        # Reduce `data` by addition, starting from an initial value of 0.0.
        update_weight = data.reduce(0.0, lambda x, y: x + y)
        return client_works.ClientResult(
            update=update, update_weight=update_weight)

    state_type = empty_init_fn.type_signature.result
    weights_type = computation_types.at_clients(MODEL_WEIGHTS_TYPE)

    @computations.federated_computation(
        state_type, weights_type, CLIENTS_SEQUENCE_FLOAT_TYPE)
    def next_fn(state, value, client_data):
        result = intrinsics.federated_map(make_result, (value, client_data))
        measurements = empty_at_server()
        return measured_process.MeasuredProcessOutput(
            state, result, measurements)

    process = client_works.ClientWorkProcess(empty_init_fn, next_fn)
    return process
def test_construction_with_empty_state_does_not_raise(self):
    """Constructing a `ClientWorkProcess` with an empty `()` state succeeds.

    The initializer returns a server-placed empty tuple and `next_fn` passes
    that state through unchanged. Any `Exception` raised during construction
    is reported as a test failure.
    """
    initialize_fn = computations.federated_computation()(
        lambda: intrinsics.federated_value((), placements.SERVER))

    @computations.federated_computation(
        initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE,
        CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        return MeasuredProcessOutput(
            state,
            test_client_result(weights, data),
            intrinsics.federated_value(1, placements.SERVER))

    try:
        client_works.ClientWorkProcess(initialize_fn, next_fn)
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate instead of being
    # reported as a failed assertion.
    except Exception:  # pylint: disable=broad-except
        self.fail(
            'Could not construct an ClientWorkProcess with empty state.')
def test_trainable_weights_not_asignable_from_update_raises(self):
    """An update cast to `tf.float64` raises `ClientResultTypeError`.

    The client update is passed through a cast to `tf.float64` before being
    packed into the `ClientResult`.
    """

    @computations.tf_computation
    def bad_cast_fn(x):
        return tf.cast(x, tf.float64)

    @computations.federated_computation(
        SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
    def next_fn(state, weights, data):
        reduced_data = intrinsics.federated_map(tf_data_sum, data)
        summed = federated_add(weights.trainable, reduced_data)
        not_assignable_update = intrinsics.federated_map(bad_cast_fn, summed)
        update_weight = client_one()
        zipped_result = intrinsics.federated_zip(
            client_works.ClientResult(not_assignable_update, update_weight))
        measurements = server_zero()
        return MeasuredProcessOutput(state, zipped_result, measurements)

    with self.assertRaises(client_works.ClientResultTypeError):
        client_works.ClientWorkProcess(
            initialize_fn=test_initialize_fn, next_fn=next_fn)
def test_init_not_tff_computation_raises(self):
    """Passing an undecorated Python callable as `initialize_fn` raises TypeError."""

    # Deliberately left as a plain Python function (no computation decorator).
    def plain_python_init_fn():
        return 0

    expected_message = r'Expected .*\.Computation, .*'
    with self.assertRaisesRegex(TypeError, expected_message):
        client_works.ClientWorkProcess(
            initialize_fn=plain_python_init_fn, next_fn=test_next_fn)
def test_construction_does_not_raise(self):
    """Constructing from `test_initialize_fn` and `test_next_fn` succeeds.

    Any `Exception` raised during construction is reported as a test failure.
    """
    try:
        client_works.ClientWorkProcess(test_initialize_fn, test_next_fn)
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate instead of being
    # reported as a failed assertion.
    except Exception:  # pylint: disable=broad-except
        self.fail('Could not construct a valid ClientWorkProcess.')
def test_init_param_not_empty_raises(self):
    """An `initialize_fn` that takes a parameter is rejected.

    Construction is expected to raise
    `errors.TemplateInitFnParamNotEmptyError`.
    """

    @computations.federated_computation(SERVER_INT)
    def one_arg_initialize_fn(x):
        return x

    with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
        client_works.ClientWorkProcess(
            initialize_fn=one_arg_initialize_fn, next_fn=test_next_fn)
def test_init_state_not_assignable(self):
    """A float initial state paired with `test_next_fn` is rejected.

    The initializer yields a server-placed `0.0`; construction is expected to
    raise `errors.TemplateStateNotAssignableError`.
    """

    @computations.federated_computation()
    def float_initialize_fn():
        return intrinsics.federated_value(0.0, placements.SERVER)

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        client_works.ClientWorkProcess(
            initialize_fn=float_initialize_fn, next_fn=test_next_fn)