def test_partitions_value_with_no_clients_arguments(self):
    """A purely server-placed value is expected to appear once per subround."""
    server_type = computation_types.at_server(tf.int32)
    subrounds = mergeable_comp_execution_context._split_value_into_subrounds(
        0, server_type, 2)
    self.assertEqual(subrounds, [0, 0])
def assertRoundTripEqual(self,
                         value,
                         type_signature,
                         expected_round_trip_value,
                         num_desired_subrounds=2):
    """Asserts that splitting and then repackaging `value` yields the expectation.

    The value is split with `_split_value_into_subrounds` and the result is
    fed back through `_repackage_partitioned_values`. The output of that
    round trip is compared against `expected_round_trip_value`.

    Args:
      value: The value to split into subrounds.
      type_signature: Type of `value`, passed to both the split and the
        repackage helpers.
      expected_round_trip_value: The value the repackaging step is expected to
        return.
      num_desired_subrounds: Number of subrounds requested from the split
        helper. Defaults to 2, the value previously hard-coded here, so
        existing callers are unaffected.
    """
    partitioned_value = (
        mergeable_comp_execution_context._split_value_into_subrounds(
            value, type_signature, num_desired_subrounds))
    round_trip_value = (
        mergeable_comp_execution_context._repackage_partitioned_values(
            partitioned_value, type_signature))
    self.assertEqual(round_trip_value, expected_round_trip_value)
def test_partitions_fewer_clients_than_rounds_into_nonempty_rounds(self):
    """Requesting more subrounds than there are clients yields no empty ones.

    Two client values with five requested subrounds are expected to produce
    exactly two subrounds holding one client value each.
    """
    clients_type = computation_types.at_clients(tf.int32)
    subrounds = mergeable_comp_execution_context._split_value_into_subrounds(
        [0, 1], clients_type, 5)
    self.assertEqual(subrounds, [[0], [1]])
def test_partitions_client_placed_value_into_subrounds(self):
    """Ten client values are expected to split into five contiguous pairs."""
    clients_type = computation_types.at_clients(tf.int32)
    subrounds = mergeable_comp_execution_context._split_value_into_subrounds(
        list(range(10)), clients_type, 5)
    expected = [[start, start + 1] for start in range(0, 10, 2)]
    self.assertEqual(subrounds, expected)
def test_replicates_all_equal_clients_argument(self):
    """An all-equal clients-placed element is expected to be copied per subround."""
    struct_type = computation_types.StructType([
        (None, computation_types.at_server(tf.int32)),
        (None, computation_types.at_clients(tf.int32, all_equal=True)),
    ])
    subrounds = mergeable_comp_execution_context._split_value_into_subrounds(
        (0, 1), struct_type, 2)
    expected = [(0, 1)] * 2
    self.assertEqual(subrounds, expected)
def test_wraps_value_with_empty_client_argument(self):
    """An empty clients element yields a single wrapped subround.

    Two subrounds are requested. Because the clients-placed element is empty,
    the expected result is a one-element list containing the original value.
    """
    struct_type = computation_types.StructType([
        (None, computation_types.at_server(tf.int32)),
        (None, computation_types.at_clients(tf.int32)),
    ])
    subrounds = mergeable_comp_execution_context._split_value_into_subrounds(
        (0, []), struct_type, 2)
    self.assertEqual(subrounds, [(0, [])])
def test_partitions_clients_placed_struct_elem_into_subrounds(self):
    """In a named struct, the server element repeats while clients are split.

    Each of the five expected subrounds keeps the server value 0 under name
    'a' and receives a contiguous pair of client values under name 'b'.
    """
    server_name, clients_name = 'a', 'b'
    struct_type = computation_types.StructType([
        (server_name, computation_types.at_server(tf.int32)),
        (clients_name, computation_types.at_clients(tf.int32)),
    ])
    expected = [
        structure.Struct([(server_name, 0), (clients_name, [first, first + 1])])
        for first in range(0, 10, 2)
    ]
    subrounds = mergeable_comp_execution_context._split_value_into_subrounds(
        (0, list(range(10))), struct_type, 5)
    self.assertEqual(subrounds, expected)