def test_partitions_value_with_no_clients_arguments(self):
   # The argument is a single server-placed int. With two subrounds requested,
   # the test expects that same int to appear once per subround.
   subround_count = 2
   server_type = computation_types.at_server(tf.int32)
   subrounds = mergeable_comp_execution_context._split_value_into_subrounds(
       0, server_type, subround_count)
   expected_subrounds = [0, 0]
   self.assertEqual(subrounds, expected_subrounds)
 def assertRoundTripEqual(self, value, type_signature,
                          expected_round_trip_value):
   num_desired_subrounds = 2
   partitioned_value = mergeable_comp_execution_context._split_value_into_subrounds(
       value, type_signature, num_desired_subrounds)
   self.assertEqual(
       mergeable_comp_execution_context._repackage_partitioned_values(
           partitioned_value, type_signature), expected_round_trip_value)
 def test_partitions_fewer_clients_than_rounds_into_nonempty_rounds(self):
   value = [0, 1]
   type_signature = computation_types.at_clients(tf.int32)
   num_desired_subrounds = 5
   partitioned_value = mergeable_comp_execution_context._split_value_into_subrounds(
       value, type_signature, num_desired_subrounds)
   expected_partitioning = [[0], [1]]
   self.assertEqual(partitioned_value, expected_partitioning)
 def test_partitions_client_placed_value_into_subrounds(self):
   value = list(range(10))
   type_signature = computation_types.at_clients(tf.int32)
   num_desired_subrounds = 5
   partitioned_value = mergeable_comp_execution_context._split_value_into_subrounds(
       value, type_signature, num_desired_subrounds)
   expected_partitioning = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
   self.assertEqual(partitioned_value, expected_partitioning)
 def test_replicates_all_equal_clients_argument(self):
   value = (0, 1)
   type_signature = computation_types.StructType([
       (None, computation_types.at_server(tf.int32)),
       (None, computation_types.at_clients(tf.int32, all_equal=True))
   ])
   num_desired_subrounds = 2
   partitioned_value = mergeable_comp_execution_context._split_value_into_subrounds(
       value, type_signature, num_desired_subrounds)
   self.assertEqual(partitioned_value, [(0, 1), (0, 1)])
 def test_wraps_value_with_empty_client_argument(self):
   value = (0, [])
   type_signature = computation_types.StructType([
       (None, computation_types.at_server(tf.int32)),
       (None, computation_types.at_clients(tf.int32))
   ])
   num_desired_subrounds = 2
   partitioned_value = mergeable_comp_execution_context._split_value_into_subrounds(
       value, type_signature, num_desired_subrounds)
   self.assertEqual(partitioned_value, [(0, [])])
  def test_partitions_clients_placed_struct_elem_into_subrounds(self):
    value = (0, list(range(10)))
    server_placed_name = 'a'
    clients_placed_name = 'b'
    type_signature = computation_types.StructType([
        (server_placed_name, computation_types.at_server(tf.int32)),
        (clients_placed_name, computation_types.at_clients(tf.int32))
    ])

    num_desired_subrounds = 5
    expected_partitioning = []
    for j in range(0, 10, 2):
      expected_struct_partition = structure.Struct([(server_placed_name, 0),
                                                    (clients_placed_name,
                                                     [j, j + 1])])
      expected_partitioning.append(expected_struct_partition)
    partitioned_value = mergeable_comp_execution_context._split_value_into_subrounds(
        value, type_signature, num_desired_subrounds)
    self.assertEqual(partitioned_value, expected_partitioning)