def get_example_cf_compatible_iterative_processes():
  """Returns named example iterative processes for canonical-form tests.

  As the function name indicates, these are intended to be convertible to
  canonical form.

  Returns:
    A tuple of `(name, iterative_process)` pairs, in the shape accepted by
    `parameterized.named_parameters`. Every process is constructed anew (each
    factory is called) each time this function is invoked.
  """
  # The long factory names below exceed the line limit; keep formatter off only
  # for this literal and re-enable it immediately afterwards.
  # pyformat: disable
  return (
      ('sum_example',
       get_iterative_process_for_sum_example()),
      ('sum_example_with_no_prepare',
       get_iterative_process_for_sum_example_with_no_prepare()),
      ('sum_example_with_no_broadcast',
       get_iterative_process_for_sum_example_with_no_broadcast()),
      ('sum_example_with_no_federated_aggregate',
       get_iterative_process_for_sum_example_with_no_federated_aggregate()),
      ('sum_example_with_no_federated_secure_sum_bitwidth',
       get_iterative_process_for_sum_example_with_no_federated_secure_sum_bitwidth()),
      ('sum_example_with_no_update',
       get_iterative_process_for_sum_example_with_no_update()),
      ('sum_example_with_no_server_state',
       get_iterative_process_for_sum_example_with_no_server_state()),
      ('minimal_sum_example',
       get_iterative_process_for_minimal_sum_example()),
      ('example_with_unused_lambda_arg',
       mapreduce_test_utils.get_iterative_process_for_example_with_unused_lambda_arg()),
      ('example_with_unused_tf_computation_arg',
       mapreduce_test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()))
  # pyformat: enable


class GetCanonicalFormForIterativeProcessTest(CanonicalFormTestCase,
                                              parameterized.TestCase):
  """Tests for `canonical_form_utils.get_canonical_form_for_iterative_process`."""

  def test_next_computation_returning_tensor_fails_well(self):
    cf = test_utils.get_temperature_sensor_example()
    it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    init_result = it.initialize.type_signature.result
    # Replace `next` with the identity function on the `initialize` result
    # type; converting such a process is expected to raise `TypeError`.
    lam = building_blocks.Lambda('x', init_result,
                                 building_blocks.Reference('x', init_result))
    bad_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(lam))
    with self.assertRaises(TypeError):
      canonical_form_utils.get_canonical_form_for_iterative_process(bad_it)

  def test_broadcast_dependent_on_aggregate_fails_well(self):
    cf = test_utils.get_temperature_sensor_example()
    it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    next_comp = test_utils.computation_to_building_block(it.next)
    top_level_param = building_blocks.Reference(next_comp.parameter_name,
                                                next_comp.parameter_type)
    # Chain `next` twice inside one computation: the second call receives
    # element 0 of the first call's result plus element 1 of the original
    # parameter, so the second round depends on the first round's output.
    first_result = building_blocks.Call(next_comp, top_level_param)
    middle_param = building_blocks.Tuple([
        building_blocks.Selection(first_result, index=0),
        building_blocks.Selection(top_level_param, index=1)
    ])
    second_result = building_blocks.Call(next_comp, middle_param)
    not_reducible = building_blocks.Lambda(next_comp.parameter_name,
                                           next_comp.parameter_type,
                                           second_result)
    not_reducible_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(
            not_reducible))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      canonical_form_utils.get_canonical_form_for_iterative_process(
          not_reducible_it)

  def test_constructs_canonical_form_from_mnist_training_example(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  def test_temperature_example_round_trip(self):
    # IterativeProcess -> CanonicalForm -> IterativeProcess, then run two
    # rounds of the reconstructed process and check state/metrics/outputs.
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    state = new_it.initialize()
    self.assertLen(state, 1)
    self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
    self.assertEqual(state[0], 0)

    state, metrics, stats = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertLen(state, 1)
    self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
    self.assertEqual(state[0], 1)
    self.assertLen(metrics, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics), ['ratio_over_threshold'])
    self.assertEqual(metrics[0], 0.5)
    self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                          [1, 3])

    state, metrics, stats = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    self.assertAllEqual(state, (2,))
    self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
    # Evaluate the per-client outputs the same way as in the first round.
    self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                          [1, 1, 1, 1])
    # The round trip must not introduce additional TF variables.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))

  def test_mnist_training_round_trip(self):
    # The original and round-tripped processes must agree on initial state
    # and on the state/metrics after one round on identical client data.
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    state1 = it.initialize()
    state2 = new_it.initialize()
    # Compare structure (names) and values directly instead of comparing
    # string representations, which is brittle.
    self.assertAllEqual(
        anonymous_tuple.name_list(state1), anonymous_tuple.name_list(state2))
    self.assertAllClose(state1, state2)
    dummy_x = np.array([[0.5] * 784], dtype=np.float32)
    dummy_y = np.array([1], dtype=np.int32)
    client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]
    round_1 = it.next(state1, [client_data])
    state = round_1[0]
    metrics = round_1[1]
    alt_round_1 = new_it.next(state2, [client_data])
    alt_state = alt_round_1[0]
    alt_metrics = alt_round_1[1]
    self.assertAllEqual(
        anonymous_tuple.name_list(state), anonymous_tuple.name_list(alt_state))
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics),
        anonymous_tuple.name_list(alt_metrics))
    self.assertAllClose(state, alt_state)
    self.assertAllClose(metrics, alt_metrics)
    # The round trip must not introduce additional TF variables.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))

  # pyformat: disable
  @parameterized.named_parameters(
      ('sum_example',
       get_iterative_process_for_sum_example()),
      ('sum_example_with_no_prepare',
       get_iterative_process_for_sum_example_with_no_prepare()),
      ('sum_example_with_no_broadcast',
       get_iterative_process_for_sum_example_with_no_broadcast()),
      ('sum_example_with_no_client_output',
       get_iterative_process_for_sum_example_with_no_client_output()),
      ('sum_example_with_no_federated_aggregate',
       get_iterative_process_for_sum_example_with_no_federated_aggregate()),
      ('sum_example_with_no_federated_secure_sum',
       get_iterative_process_for_sum_example_with_no_federated_secure_sum()),
      ('sum_example_with_no_update',
       get_iterative_process_for_sum_example_with_no_update()),
      ('sum_example_with_no_server_state',
       get_iterative_process_for_sum_example_with_no_server_state()),
      ('minimal_sum_example',
       get_iterative_process_for_minimal_sum_example()),
      ('example_with_unused_lambda_arg',
       test_utils.get_iterative_process_for_example_with_unused_lambda_arg()),
      ('example_with_unused_tf_computation_arg',
       test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()),
  )
  # pyformat: enable
  def test_returns_canonical_form(self, ip):
    # Each example process (built when the class body is evaluated) must
    # convert to a `CanonicalForm`.
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)

    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  def test_raises_value_error_for_sum_example_with_no_aggregation(self):
    ip = get_iterative_process_for_sum_example_with_no_aggregation()

    with self.assertRaises(ValueError):
      canonical_form_utils.get_canonical_form_for_iterative_process(ip)
# Example #3
# 0
class GetCanonicalFormForIterativeProcessTest(CanonicalFormTestCase,
                                              parameterized.TestCase):
    """Tests for `canonical_form_utils.get_canonical_form_for_iterative_process`."""

    def test_next_computation_returning_tensor_fails_well(self):
        cf = test_utils.get_temperature_sensor_example()
        it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
        init_result = it.initialize.type_signature.result
        # Replace `next` with the identity function on the `initialize`
        # result type; converting such a process should raise `TypeError`.
        lam = building_blocks.Lambda(
            'x', init_result, building_blocks.Reference('x', init_result))
        bad_it = iterative_process.IterativeProcess(
            it.initialize,
            computation_wrapper_instances.building_block_to_computation(lam))
        with self.assertRaises(TypeError):
            canonical_form_utils.get_canonical_form_for_iterative_process(
                bad_it)

    def test_broadcast_dependent_on_aggregate_fails_well(self):
        cf = test_utils.get_temperature_sensor_example()
        it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
        next_comp = test_utils.computation_to_building_block(it.next)
        top_level_param = building_blocks.Reference(next_comp.parameter_name,
                                                    next_comp.parameter_type)
        # Chain `next` twice inside one computation: the second call receives
        # element 0 of the first call's result plus element 1 of the original
        # parameter, so the second round depends on the first round's output.
        first_result = building_blocks.Call(next_comp, top_level_param)
        middle_param = building_blocks.Tuple([
            building_blocks.Selection(first_result, index=0),
            building_blocks.Selection(top_level_param, index=1)
        ])
        second_result = building_blocks.Call(next_comp, middle_param)
        not_reducible = building_blocks.Lambda(next_comp.parameter_name,
                                               next_comp.parameter_type,
                                               second_result)
        not_reducible_it = iterative_process.IterativeProcess(
            it.initialize,
            computation_wrapper_instances.building_block_to_computation(
                not_reducible))

        with self.assertRaisesRegex(ValueError,
                                    'broadcast dependent on aggregate'):
            canonical_form_utils.get_canonical_form_for_iterative_process(
                not_reducible_it)

    def test_constructs_canonical_form_from_mnist_training_example(self):
        it = canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_mnist_training_example())
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        self.assertIsInstance(cf, canonical_form.CanonicalForm)

    def test_temperature_example_round_trip(self):
        # IterativeProcess -> CanonicalForm -> IterativeProcess, then run two
        # rounds of the reconstructed process and check its outputs.
        it = canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_temperature_sensor_example())
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        state = new_it.initialize()
        self.assertLen(state, 1)
        self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
        self.assertEqual(state[0], 0)

        state, metrics, stats = new_it.next(state,
                                            [[28.0], [30.0, 33.0, 29.0]])
        self.assertLen(state, 1)
        self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
        self.assertEqual(state[0], 1)
        self.assertLen(metrics, 1)
        self.assertAllEqual(anonymous_tuple.name_list(metrics),
                            ['ratio_over_threshold'])
        self.assertEqual(metrics[0], 0.5)
        self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                              [1, 3])

        state, metrics, stats = new_it.next(state,
                                            [[33.0], [34.0], [35.0], [36.0]])
        self.assertAllEqual(state, (2, ))
        self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
        # Evaluate the per-client outputs the same way as in the first round.
        self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                              [1, 1, 1, 1])
        # The round trip must not introduce additional TF variables.
        self.assertEqual(
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(it.next)),
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(new_it.next)))

    def test_mnist_training_round_trip(self):
        # The original and round-tripped processes must agree on initial
        # state and on state/metrics after one round on identical data.
        it = canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_mnist_training_example())
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        state1 = it.initialize()
        state2 = new_it.initialize()
        # Compare structure (names) and values directly instead of comparing
        # string representations, which is brittle.
        self.assertAllEqual(anonymous_tuple.name_list(state1),
                            anonymous_tuple.name_list(state2))
        self.assertAllClose(state1, state2)
        dummy_x = np.array([[0.5] * 784], dtype=np.float32)
        dummy_y = np.array([1], dtype=np.int32)
        client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]
        round_1 = it.next(state1, [client_data])
        state = round_1[0]
        metrics = round_1[1]
        alt_round_1 = new_it.next(state2, [client_data])
        alt_state = alt_round_1[0]
        alt_metrics = alt_round_1[1]
        self.assertAllEqual(anonymous_tuple.name_list(state),
                            anonymous_tuple.name_list(alt_state))
        self.assertAllEqual(anonymous_tuple.name_list(metrics),
                            anonymous_tuple.name_list(alt_metrics))
        self.assertAllClose(state, alt_state)
        self.assertAllClose(metrics, alt_metrics)
        # The round trip must not introduce additional TF variables.
        self.assertEqual(
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(it.next)),
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(new_it.next)))

    def test_canonical_form_from_tff_learning_structure_type_spec(self):
        it = test_utils.construct_example_training_comp()
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)

        work_type_spec = cf.work.type_signature

        # This type spec test actually carries the meaning that TFF's vanilla path
        # to canonical form will broadcast and aggregate exactly one copy of the
        # parameters. So the type test below in fact functions as a regression test
        # for the TFF compiler pipeline.
        # pyformat: disable
        expected_type_string = '(<<x=float32[?,2],y=int32[?,1]>*,<<trainable=<float32[2,1],float32[1]>,non_trainable=<>>>> -> <<<<<float32[2,1],float32[1]>,float32>,<float32,float32>,<float32,float32>,<float32>>,<>>,<>>)'
        # pyformat: enable
        self.assertEqual(work_type_spec.compact_representation(),
                         expected_type_string)

    def test_returns_canonical_form_from_tff_learning_structure(self):
        it = test_utils.construct_example_training_comp()
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        self.assertIsInstance(cf, canonical_form.CanonicalForm)
        self.assertEqual(it.initialize.type_signature,
                         new_it.initialize.type_signature)
        # Notice next type_signatures need not be equal, since we may have appended
        # an empty tuple as client side-channel outputs if none existed
        self.assertEqual(it.next.type_signature.parameter,
                         new_it.next.type_signature.parameter)
        self.assertEqual(it.next.type_signature.result[0],
                         new_it.next.type_signature.result[0])
        self.assertEqual(it.next.type_signature.result[1],
                         new_it.next.type_signature.result[1])

        state1 = it.initialize()
        state2 = new_it.initialize()

        sample_batch = collections.OrderedDict(x=np.array([[1., 1.]],
                                                          dtype=np.float32),
                                               y=np.array([[0]],
                                                          dtype=np.int32))
        client_data = [sample_batch]

        round_1 = it.next(state1, [client_data])
        state = round_1[0]
        state_names = anonymous_tuple.name_list(state)
        state_arrays = anonymous_tuple.flatten(state)
        metrics = round_1[1]
        # Names taken from `iter_elements` (including any unnamed entries),
        # not `name_list`; kept that way deliberately.
        metrics_names = [x[0] for x in anonymous_tuple.iter_elements(metrics)]
        metrics_arrays = anonymous_tuple.flatten(metrics)

        alt_round_1 = new_it.next(state2, [client_data])
        alt_state = alt_round_1[0]
        alt_state_names = anonymous_tuple.name_list(alt_state)
        alt_state_arrays = anonymous_tuple.flatten(alt_state)
        alt_metrics = alt_round_1[1]
        alt_metrics_names = [
            x[0] for x in anonymous_tuple.iter_elements(alt_metrics)
        ]
        alt_metrics_arrays = anonymous_tuple.flatten(alt_metrics)

        self.assertEmpty(state.delta_aggregate_state)
        self.assertEmpty(state.model_broadcast_state)
        self.assertAllEqual(state_names, alt_state_names)
        self.assertAllEqual(metrics_names, alt_metrics_names)
        self.assertAllClose(state_arrays, alt_state_arrays)
        self.assertAllClose(metrics_arrays[:2], alt_metrics_arrays[:2])
        # Final metric is execution time
        self.assertAlmostEqual(metrics_arrays[2],
                               alt_metrics_arrays[2],
                               delta=1e-3)

    # pyformat: disable
    @parameterized.named_parameters(
        ('sum_example', get_iterative_process_for_sum_example()),
        ('sum_example_with_no_prepare',
         get_iterative_process_for_sum_example_with_no_prepare()),
        ('sum_example_with_no_broadcast',
         get_iterative_process_for_sum_example_with_no_broadcast()),
        ('sum_example_with_no_client_output',
         get_iterative_process_for_sum_example_with_no_client_output()),
        ('sum_example_with_no_federated_aggregate',
         get_iterative_process_for_sum_example_with_no_federated_aggregate()),
        ('sum_example_with_no_federated_secure_sum',
         get_iterative_process_for_sum_example_with_no_federated_secure_sum()),
        ('sum_example_with_no_update',
         get_iterative_process_for_sum_example_with_no_update()),
        ('sum_example_with_no_server_state',
         get_iterative_process_for_sum_example_with_no_server_state()),
        ('minimal_sum_example',
         get_iterative_process_for_minimal_sum_example()),
        ('example_with_unused_lambda_arg',
         test_utils.get_iterative_process_for_example_with_unused_lambda_arg()
         ),
        ('example_with_unused_tf_computation_arg',
         test_utils.
         get_iterative_process_for_example_with_unused_tf_computation_arg()),
    )
    # pyformat: enable
    def test_returns_canonical_form(self, ip):
        # Each example process must convert to a `CanonicalForm`.
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)

        self.assertIsInstance(cf, canonical_form.CanonicalForm)

    def test_raises_value_error_for_sum_example_with_no_aggregation(self):
        ip = get_iterative_process_for_sum_example_with_no_aggregation()

        with self.assertRaises(ValueError):
            canonical_form_utils.get_canonical_form_for_iterative_process(ip)
# Example #4
# 0
class GetCanonicalFormForIterativeProcessTest(CanonicalFormTestCase,
                                              parameterized.TestCase):
    """Tests for `canonical_form_utils.get_canonical_form_for_iterative_process`."""

    def test_next_computation_returning_tensor_fails_well(self):
        cf = test_utils.get_temperature_sensor_example()
        it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
        init_result = it.initialize.type_signature.result
        # Replace `next` with the identity function on the `initialize`
        # result type; converting such a process should raise `TypeError`.
        lam = building_blocks.Lambda(
            'x', init_result, building_blocks.Reference('x', init_result))
        bad_it = iterative_process.IterativeProcess(
            it.initialize,
            computation_wrapper_instances.building_block_to_computation(lam))
        with self.assertRaises(TypeError):
            canonical_form_utils.get_canonical_form_for_iterative_process(
                bad_it)

    def test_broadcast_dependent_on_aggregate_fails_well(self):
        cf = test_utils.get_temperature_sensor_example()
        it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
        next_comp = test_utils.computation_to_building_block(it.next)
        top_level_param = building_blocks.Reference(next_comp.parameter_name,
                                                    next_comp.parameter_type)
        # Chain `next` twice inside one computation: the second call receives
        # element 0 of the first call's result plus element 1 of the original
        # parameter, so the second round depends on the first round's output.
        first_result = building_blocks.Call(next_comp, top_level_param)
        middle_param = building_blocks.Struct([
            building_blocks.Selection(first_result, index=0),
            building_blocks.Selection(top_level_param, index=1)
        ])
        second_result = building_blocks.Call(next_comp, middle_param)
        not_reducible = building_blocks.Lambda(next_comp.parameter_name,
                                               next_comp.parameter_type,
                                               second_result)
        not_reducible_it = iterative_process.IterativeProcess(
            it.initialize,
            computation_wrapper_instances.building_block_to_computation(
                not_reducible))

        with self.assertRaisesRegex(ValueError,
                                    'broadcast dependent on aggregate'):
            canonical_form_utils.get_canonical_form_for_iterative_process(
                not_reducible_it)

    def test_constructs_canonical_form_from_mnist_training_example(self):
        it = canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_mnist_training_example())
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        self.assertIsInstance(cf, canonical_form.CanonicalForm)

    def test_temperature_example_round_trip(self):
        # NOTE: the roundtrip through CanonicalForm->IterProc->CanonicalForm seems
        # to lose the python container annotations on the StructType.
        it = canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_temperature_sensor_example())
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        state = new_it.initialize()
        self.assertEqual(state.num_rounds, 0)

        state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertEqual(state.num_rounds, 1)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.75))
        # The round trip must not introduce additional TF variables.
        self.assertEqual(
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(it.next)),
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(new_it.next)))

    def test_mnist_training_round_trip(self):
        # The original and round-tripped processes must agree on initial
        # state and on state/metrics after one round on identical data.
        it = canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_mnist_training_example())
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        state1 = it.initialize()
        state2 = new_it.initialize()
        self.assertAllClose(state1, state2)
        dummy_x = np.array([[0.5] * 784], dtype=np.float32)
        dummy_y = np.array([1], dtype=np.int32)
        client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]
        round_1 = it.next(state1, [client_data])
        state = round_1[0]
        metrics = round_1[1]
        alt_round_1 = new_it.next(state2, [client_data])
        alt_state = alt_round_1[0]
        self.assertAllClose(state, alt_state)
        alt_metrics = alt_round_1[1]
        self.assertAllClose(metrics, alt_metrics)
        # The round trip must not introduce additional TF variables.
        self.assertEqual(
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(it.next)),
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(new_it.next)))

    def _cf_compatible_examples():  # pylint: disable=no-method-argument
        """Builds `(name, iterative_process)` pairs for parameterized tests.

        Evaluated in the class body (not called on instances) so that both
        parameterized tests below share a single definition of the example
        list. It is called once per decorator, so each test gets its own,
        independently constructed processes. Deleted from the class namespace
        once the decorators have run.
        """
        # pyformat: disable
        return (
            ('sum_example', get_iterative_process_for_sum_example()),
            ('sum_example_with_no_prepare',
             get_iterative_process_for_sum_example_with_no_prepare()),
            ('sum_example_with_no_broadcast',
             get_iterative_process_for_sum_example_with_no_broadcast()),
            ('sum_example_with_no_federated_aggregate',
             get_iterative_process_for_sum_example_with_no_federated_aggregate()),
            ('sum_example_with_no_federated_secure_sum',
             get_iterative_process_for_sum_example_with_no_federated_secure_sum()),
            ('sum_example_with_no_update',
             get_iterative_process_for_sum_example_with_no_update()),
            ('sum_example_with_no_server_state',
             get_iterative_process_for_sum_example_with_no_server_state()),
            ('minimal_sum_example',
             get_iterative_process_for_minimal_sum_example()),
            ('example_with_unused_lambda_arg',
             test_utils.get_iterative_process_for_example_with_unused_lambda_arg()),
            ('example_with_unused_tf_computation_arg',
             test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()),
        )
        # pyformat: enable

    @parameterized.named_parameters(*_cf_compatible_examples())
    def test_returns_canonical_form(self, ip):
        # Each example process must convert to a `CanonicalForm` with the
        # default conversion settings.
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)

        self.assertIsInstance(cf, canonical_form.CanonicalForm)

    @parameterized.named_parameters(*_cf_compatible_examples())
    def test_returns_canonical_form_with_grappler_disabled(self, ip):
        # Passing `None` as the second argument is how this test disables
        # Grappler (per the test name).
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(
            ip, None)

        self.assertIsInstance(cf, canonical_form.CanonicalForm)

    # The helper is only needed while the class body executes.
    del _cf_compatible_examples

    def test_raises_value_error_for_sum_example_with_no_aggregation(self):
        ip = get_iterative_process_for_sum_example_with_no_aggregation()

        with self.assertRaisesRegex(
                ValueError,
                r'Expected .* containing at least one `federated_aggregate` or '
                r'`federated_secure_sum`'):
            canonical_form_utils.get_canonical_form_for_iterative_process(ip)

    def test_returns_canonical_form_with_indirection_to_intrinsic(self):
        self.skipTest('b/160865930')
        ip = (test_utils.
              get_iterative_process_for_example_with_lambda_returning_aggregation())

        cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)

        self.assertIsInstance(cf, canonical_form.CanonicalForm)