def test_type_signature_with_structure_of_ints(self):
  """Checks the `secure_sum` result type for a nested structure of ints.

  A CLIENTS-placed `[1, [1, 1]]` is summed using a bitwidth structure
  `[8, [4, 2]]` of the same shape. The result's compact type must be a
  SERVER-placed nested struct of int32.
  """
  client_value = intrinsics.federated_value([1, [1, 1]], placements.CLIENTS)
  per_leaf_bitwidth = [8, [4, 2]]
  summed = intrinsics.secure_sum(client_value, per_leaf_bitwidth)
  actual_type = summed.type_signature.compact_representation()
  self.assertEqual(actual_type, '<int32,<int32,int32>>@SERVER')
def test_type_signature_with_int(self):
  """Checks the `secure_sum` result type for a scalar int value.

  A CLIENTS-placed `1` is summed with a scalar bitwidth of 8. The result's
  compact type must be `int32@SERVER`.
  """
  client_value = intrinsics.federated_value(1, placements.CLIENTS)
  scalar_bitwidth = 8
  summed = intrinsics.secure_sum(client_value, scalar_bitwidth)
  actual_type = summed.type_signature.compact_representation()
  self.assertEqual(actual_type, 'int32@SERVER')
def next_fn(server_state, client_data):
  """The `next` function for `computation_utils.IterativeProcess`.

  Steps, in order:
    1. Map `prepare` over `server_state` and broadcast the result to clients.
    2. Zip the broadcast value with `client_data` and map `work` over it,
       which yields the client updates and the client output.
    3. Combine the client updates with `secure_sum` using a bitwidth of 8.
    4. Zip the combined updates with `server_state` and map `update` over
       the pair.

  Args:
    server_state: State consumed by `prepare` and by `update`.
    client_data: Data zipped with the broadcast value as input to `work`.

  Returns:
    A tuple `(new_server_state, server_output, client_output)`.
  """
  # Intrinsic calls stay in their original order. This function is
  # presumably traced, so the call sequence shapes the computation.
  prepared_state = intrinsics.federated_map(prepare, server_state)
  broadcast_state = intrinsics.federated_broadcast(prepared_state)
  work_args = intrinsics.federated_zip([client_data, broadcast_state])
  client_updates, client_output = intrinsics.federated_map(work, work_args)
  summed_updates = intrinsics.secure_sum(client_updates, 8)
  update_args = intrinsics.federated_zip([server_state, summed_updates])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_args)
  return new_server_state, server_output, client_output
def next_fn(server_state, client_data):
  """The `next` function for `computation_utils.IterativeProcess`.

  `server_state` is ignored. `work` is mapped over `client_data`. The first
  element of the client updates is combined with `federated_sum`, and the
  second with `secure_sum` using a bitwidth of 8. The two sums are zipped
  and passed to `update`.

  Args:
    server_state: Unused; deleted immediately.
    client_data: Input mapped through `work`.

  Returns:
    A tuple `(new_server_state, server_output, client_output)`.
  """
  del server_state  # Unused
  client_updates, client_output = intrinsics.federated_map(
      work, client_data)
  # Each half of the client updates is aggregated by a different intrinsic.
  plain_sum = intrinsics.federated_sum(client_updates[0])
  secure_total = intrinsics.secure_sum(client_updates[1], 8)
  update_args = intrinsics.federated_zip([plain_sum, secure_total])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_args)
  return new_server_state, server_output, client_output
def test_raises_type_error_with_different_structures(self):
  """`secure_sum` raises TypeError when the shapes do not match.

  The value is a nested structure `[1, [1, 1]]` but the bitwidth is the
  scalar 8.
  """
  nested_value = intrinsics.federated_value([1, [1, 1]], placements.CLIENTS)
  scalar_bitwidth = 8
  self.assertRaises(
      TypeError, intrinsics.secure_sum, nested_value, scalar_bitwidth)
def test_raises_type_error_with_bitwith_int_at_server(self):
  """`secure_sum` raises TypeError when the bitwidth is a SERVER-placed value.

  NOTE(review): the method name keeps the original "bitwith" spelling so
  the test id stays stable.
  """
  client_value = intrinsics.federated_value(1, placements.CLIENTS)
  server_bitwidth = intrinsics.federated_value(1, placements.SERVER)
  self.assertRaises(
      TypeError, intrinsics.secure_sum, client_value, server_bitwidth)