def test_build_dataset_split_fn_none_recon_epochs_variable(self):
    """Exercises the unsplit case with a non-constant recon epoch count.

    Expected recon output is one pass over the client batches for round 1 and
    two passes for round 2; post-recon output is one pass in both rounds.
    """
    # Six elements in batches of two: 3 batches.
    batched_client_data = tf.data.Dataset.range(6).batch(2)

    dataset_splitter = federated_trainer_utils.build_dataset_split_fn(
        recon_epochs_max=8,
        recon_epochs_constant=False,
        recon_steps_max=None,
        post_recon_epochs=1,
        post_recon_steps_max=None,
        split_dataset=False,
        split_dataset_strategy=None,
        split_dataset_proportion=None)

    expected_post_recon = [[0, 1], [2, 3], [4, 5]]
    # (round number, expected recon batches) pairs, checked in order.
    expected_recon_by_round = (
        (1, [[0, 1], [2, 3], [4, 5]]),
        (2, [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3], [4, 5]]),
    )

    for round_index, expected_recon in expected_recon_by_round:
      round_num = tf.constant(round_index, dtype=tf.int64)
      recon_split, post_recon_split = dataset_splitter(
          batched_client_data, round_num)

      recon_batches = list(recon_split.as_numpy_iterator())
      post_recon_batches = list(post_recon_split.as_numpy_iterator())

      self.assertAllEqual(recon_batches, expected_recon)
      self.assertAllEqual(post_recon_batches, expected_post_recon)
  def test_build_dataset_split_fn_aggregated_recon_epochs_variable_max_steps_multiple_post_epochs(
      self):
    """Aggregated split, variable recon epochs, step cap, 2 post-recon epochs.

    Expected recon output is the first two batches once for round 1 and twice
    for round 2; post-recon output is the last batch twice in both rounds.
    """
    # Six elements in batches of two: 3 batches.
    batched_client_data = tf.data.Dataset.range(6).batch(2)

    dataset_splitter = federated_trainer_utils.build_dataset_split_fn(
        recon_epochs_max=8,
        recon_epochs_constant=False,
        recon_steps_max=4,
        post_recon_epochs=2,
        post_recon_steps_max=None,
        split_dataset=True,
        split_dataset_strategy=federated_trainer_utils
        .SPLIT_STRATEGY_AGGREGATED,
        split_dataset_proportion=2)

    expected_post_recon = [[4, 5], [4, 5]]
    # (round number, expected recon batches) pairs, checked in order.
    expected_recon_by_round = (
        (1, [[0, 1], [2, 3]]),
        (2, [[0, 1], [2, 3], [0, 1], [2, 3]]),
    )

    for round_index, expected_recon in expected_recon_by_round:
      round_num = tf.constant(round_index, dtype=tf.int64)
      recon_split, post_recon_split = dataset_splitter(
          batched_client_data, round_num)

      recon_batches = list(recon_split.as_numpy_iterator())
      post_recon_batches = list(post_recon_split.as_numpy_iterator())

      self.assertAllEqual(recon_batches, expected_recon)
      self.assertAllEqual(post_recon_batches, expected_post_recon)
  def test_build_dataset_split_fn_aggregated_recon_max_steps(self):
    """Aggregated split, constant recon epochs, with a cap on recon steps.

    Expected recon output is the first two batches repeated twice, and
    post-recon output is the last batch, for both step caps that are tried.
    """
    # Six elements in batches of two: 3 batches.
    batched_client_data = tf.data.Dataset.range(6).batch(2)

    expected_recon = [[0, 1], [2, 3], [0, 1], [2, 3]]
    expected_post_recon = [[4, 5]]

    # The second cap (7) is more steps than are actually available, so it is
    # expected to have no effect compared with the first cap (4).
    for recon_step_cap in (4, 7):
      dataset_splitter = federated_trainer_utils.build_dataset_split_fn(
          recon_epochs_max=2,
          recon_epochs_constant=True,
          recon_steps_max=recon_step_cap,
          post_recon_epochs=1,
          post_recon_steps_max=None,
          split_dataset=True,
          split_dataset_strategy=federated_trainer_utils
          .SPLIT_STRATEGY_AGGREGATED,
          split_dataset_proportion=2)
      # Round number shouldn't matter.
      recon_split, post_recon_split = dataset_splitter(batched_client_data, 3)

      recon_batches = list(recon_split.as_numpy_iterator())
      post_recon_batches = list(post_recon_split.as_numpy_iterator())

      self.assertAllEqual(recon_batches, expected_recon)
      self.assertAllEqual(post_recon_batches, expected_post_recon)
  def test_build_dataset_split_fn_skip(self):
    """Skip split strategy with two constant recon epochs.

    Expected recon output is the first and third batches repeated twice;
    post-recon output is the second batch.
    """
    # Six elements in batches of two: 3 batches.
    batched_client_data = tf.data.Dataset.range(6).batch(2)

    dataset_splitter = federated_trainer_utils.build_dataset_split_fn(
        recon_epochs_max=2,
        recon_epochs_constant=True,
        recon_steps_max=None,
        post_recon_epochs=1,
        post_recon_steps_max=None,
        split_dataset=True,
        split_dataset_strategy=federated_trainer_utils.SPLIT_STRATEGY_SKIP,
        split_dataset_proportion=2)
    # Round number shouldn't matter.
    recon_split, post_recon_split = dataset_splitter(batched_client_data, 3)

    recon_batches = list(recon_split.as_numpy_iterator())
    post_recon_batches = list(post_recon_split.as_numpy_iterator())

    expected_recon = [[0, 1], [4, 5], [0, 1], [4, 5]]
    expected_post_recon = [[2, 3]]
    self.assertAllEqual(recon_batches, expected_recon)
    self.assertAllEqual(post_recon_batches, expected_post_recon)
  def test_build_dataset_split_aggregated_fn_split_dataset_zero_batches(self):
    """Verifies that a client with no data yields two empty splits."""
    # An empty range: 0 batches.
    empty_client_data = tf.data.Dataset.range(0).batch(2)

    dataset_splitter = federated_trainer_utils.build_dataset_split_fn(
        recon_epochs_max=1,
        recon_epochs_constant=True,
        recon_steps_max=None,
        post_recon_epochs=1,
        post_recon_steps_max=None,
        split_dataset=True,
        split_dataset_strategy=federated_trainer_utils
        .SPLIT_STRATEGY_AGGREGATED,
        split_dataset_proportion=10)

    # Round number doesn't matter: both rounds are expected to be empty.
    for round_index in (1, 2):
      round_num = tf.constant(round_index, dtype=tf.int64)
      recon_split, post_recon_split = dataset_splitter(
          empty_client_data, round_num)

      recon_batches = list(recon_split.as_numpy_iterator())
      post_recon_batches = list(post_recon_split.as_numpy_iterator())

      self.assertAllEqual(recon_batches, [])
      self.assertAllEqual(post_recon_batches, [])
  def test_build_dataset_split_fn_skip_post_recon_multiple_epochs_max_steps(
      self):
    """Skip split strategy with two post-recon epochs and a post-recon cap.

    Expected recon output is the first and third batches; post-recon output
    is the second batch repeated twice. Both rounds expect the same result.
    """
    # Six elements in batches of two: 3 batches.
    batched_client_data = tf.data.Dataset.range(6).batch(2)

    dataset_splitter = federated_trainer_utils.build_dataset_split_fn(
        recon_epochs_max=1,
        recon_epochs_constant=True,
        recon_steps_max=None,
        post_recon_epochs=2,
        post_recon_steps_max=4,
        split_dataset=True,
        split_dataset_strategy=federated_trainer_utils.SPLIT_STRATEGY_SKIP,
        split_dataset_proportion=2)

    expected_recon = [[0, 1], [4, 5]]
    expected_post_recon = [[2, 3], [2, 3]]

    # Round number doesn't matter: the same output is expected each time.
    for round_index in (1, 2):
      round_num = tf.constant(round_index, dtype=tf.int64)
      recon_split, post_recon_split = dataset_splitter(
          batched_client_data, round_num)

      recon_batches = list(recon_split.as_numpy_iterator())
      post_recon_batches = list(post_recon_split.as_numpy_iterator())

      self.assertAllEqual(recon_batches, expected_recon)
      self.assertAllEqual(post_recon_batches, expected_post_recon)