def test_removes_selected_intrinsic_leaving_remaining_intrinsic(self):
    # The body calls both `federated_aggregate` and
    # `federated_secure_sum_bitwidth`, but only the former is selected as a
    # split point.
    aggregate_call = building_block_test_utils.create_whimsy_called_federated_aggregate(
        accumulate_parameter_name='a',
        merge_parameter_name='b',
        report_parameter_name='c')
    bitwidth_call = (
        building_block_test_utils
        .create_whimsy_called_federated_secure_sum_bitwidth())
    comp = building_blocks.Lambda(
        'd', tf.int32, building_blocks.Struct([aggregate_call, bitwidth_call]))
    selected = building_block_factory.create_null_federated_aggregate()
    bitwidth_uri = bitwidth_call.function.uri
    aggregate_uri = selected.function.uri

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, [selected])

    contains = tree_analysis.contains_called_intrinsic
    # Sanity check: the input really contains both intrinsics.
    self.assertTrue(contains(comp, bitwidth_uri))
    self.assertTrue(contains(comp, aggregate_uri))
    # The selected intrinsic disappears from both halves of the split...
    self.assertFalse(contains(before, aggregate_uri))
    self.assertFalse(contains(after, aggregate_uri))
    # ...while the unselected one survives in at least one of them.
    self.assertTrue(
        contains(before, bitwidth_uri) or contains(after, bitwidth_uri))
 def test_splits_on_selected_intrinsic_aggregate(self):
     # A body whose only call is `federated_aggregate` splits cleanly on it.
     comp = building_blocks.Lambda(
         'd', tf.int32,
         building_blocks.Struct([
             building_block_test_utils.create_whimsy_called_federated_aggregate(
                 accumulate_parameter_name='a',
                 merge_parameter_name='b',
                 report_parameter_name='c')
         ]))
     self.assert_splits_on(
         comp, building_block_factory.create_null_federated_aggregate())
 def test_splits_even_when_selected_intrinsic_is_not_present(self):
     # Only `federated_aggregate` is called in the body; requesting an
     # additional split on the absent `federated_secure_sum_bitwidth` must not
     # prevent the present intrinsic from being split out.
     # Uses `building_block_test_utils`, consistent with the sibling tests.
     federated_aggregate = building_block_test_utils.create_whimsy_called_federated_aggregate(
         accumulate_parameter_name='a',
         merge_parameter_name='b',
         report_parameter_name='c')
     called_intrinsics = building_blocks.Struct([federated_aggregate])
     comp = building_blocks.Lambda('d', tf.int32, called_intrinsics)
     null_aggregate = building_block_factory.create_null_federated_aggregate()
     aggregate_uri = null_aggregate.function.uri
     before, after = transformations.force_align_and_split_by_intrinsics(
         comp, [
             null_aggregate,
             building_block_factory.create_null_federated_secure_sum_bitwidth(),
         ])
     # The intrinsic that was present is removed from both halves of the split.
     self.assertFalse(
         tree_analysis.contains_called_intrinsic(before, aggregate_uri))
     self.assertFalse(
         tree_analysis.contains_called_intrinsic(after, aggregate_uri))
 def test_splits_on_two_intrinsics(self):
     # Both intrinsics called in the body are selected as split points.
     aggregate_call = building_block_test_utils.create_whimsy_called_federated_aggregate(
         accumulate_parameter_name='a',
         merge_parameter_name='b',
         report_parameter_name='c')
     bitwidth_call = (
         building_block_test_utils
         .create_whimsy_called_federated_secure_sum_bitwidth())
     body = building_blocks.Struct([aggregate_call, bitwidth_call])
     comp = building_blocks.Lambda('d', tf.int32, body)
     split_points = [
         building_block_factory.create_null_federated_aggregate(),
         building_block_factory.create_null_federated_secure_sum_bitwidth(),
     ]
     self.assert_splits_on(comp, split_points)
def _split_ast_on_aggregate(bb):
    """Force-aligns and splits `bb` on its aggregation intrinsics.

    Both `federated_aggregate` and `federated_secure_sum_bitwidth` are used as
    split points.

    Args:
      bb: An AST containing `federated_aggregate` or
        `federated_secure_sum_bitwidth` aggregations.

    Returns:
      The pair of ASTs returned by
      `transformations.force_align_and_split_by_intrinsics`. The first maps
      `bb`'s input to the arguments of `federated_aggregate` and
      `federated_secure_sum_bitwidth`. The second maps `bb`'s input, together
      with the outputs of those intrinsics, to `bb`'s output.
    """
    split_points = [
        building_block_factory.create_null_federated_aggregate(),
        building_block_factory.create_null_federated_secure_sum_bitwidth(),
    ]
    return transformations.force_align_and_split_by_intrinsics(bb, split_points)
def compile_to_mergeable_comp_form(
    comp: computation_impl.ConcreteComputation
) -> mergeable_comp_execution_context.MergeableCompForm:
    """Compiles a computation with a single aggregation to `MergeableCompForm`.

    Compilation proceeds by splitting on the lone aggregation, and using the
    aggregation's internal functions to generate a semantically equivalent
    instance of `mergeable_comp_execution_context.MergeableCompForm`.

    Concretely:

    * Intrinsics in `comp` are replaced with their bodies, and the body is put
      in call-dominant form.
    * The result is force-aligned and split on `federated_aggregate` only.
    * The zero, accumulate, merge and report functions are extracted from the
      `federated_aggregate` arguments produced by the "before" half.

    The returned form's parts are:

    * `up_to_merge`: runs the aggregation with an identity report function, so
      it yields the merged but not-yet-reported value.
    * `merge`: the aggregation's own merge function.
    * `after_merge`: applies the original report function at the server and
      then runs the "after" half.

    Args:
      comp: Instance of `computation_impl.ConcreteComputation` to compile.
        Assumed to be representable as a computation with a single aggregation
        in its body. For example, two parallel aggregations are allowed, but
        multiple dependent aggregations are disallowed. Additionally assumed to
        be of functional type.

    Returns:
      A semantically equivalent instance of
      `mergeable_comp_execution_context.MergeableCompForm`. Its `after_merge`
      computation is annotated with `comp`'s original result type.

    Raises:
      TypeError: If `comp`'s building block is not of functional TFF type.
        NOTE(review): presumably raised by `_ensure_lambda`; confirm.
      ValueError: If `comp` cannot be represented as a computation with at most
        one aggregation in its body. NOTE(review): presumably raised by
        `tree_analysis.check_aggregate_not_dependent_on_aggregate`; confirm.
    """
    # Kept so that `after_merge` can be re-annotated with it at the end.
    original_return_type = comp.type_signature.result
    building_block = comp.to_building_block()
    lam = _ensure_lambda(building_block)
    # Lower intrinsics to their bodies; the second return value is unused here.
    lowered_bb, _ = tree_transformations.replace_intrinsics_with_bodies(lam)

    # We transform the body of this computation to easily preserve the top-level
    # lambda required by force-aligning.
    call_dominant_body_bb = transformations.to_call_dominant(lowered_bb.result)
    call_dominant_bb = building_blocks.Lambda(lowered_bb.parameter_name,
                                              lowered_bb.parameter_type,
                                              call_dominant_body_bb)

    # This check should not throw false positives because we just ensured we are
    # in call-dominant form.
    tree_analysis.check_aggregate_not_dependent_on_aggregate(call_dominant_bb)

    # Split only on `federated_aggregate`. `before_agg` exposes that intrinsic's
    # arguments under the 'federated_aggregate_param' key of its result, and
    # `after_agg` consumes the aggregation result.
    before_agg, after_agg = transformations.force_align_and_split_by_intrinsics(
        call_dominant_bb,
        [building_block_factory.create_null_federated_aggregate()])

    # Construct a report function which accepts the result of merge.
    # The `federated_aggregate` argument order is (value, zero, accumulate,
    # merge, report), matching the `intrinsics.federated_aggregate` calls below.
    # Index 3 is therefore the merge function.
    merge_fn_type = before_agg.type_signature.result[
        'federated_aggregate_param'][3]
    # An identity report makes `up_to_merge` return the merged value as-is.
    # The real report function is applied later, in `after_merge`.
    identity_report = computation_impl.ConcreteComputation.from_building_block(
        building_block_factory.create_compiled_identity(merge_fn_type.result))

    zero_comp, accumulate_comp, merge_comp, report_comp = _extract_federated_aggregate_computations(
        before_agg)

    before_agg_callable = computation_impl.ConcreteComputation.from_building_block(
        before_agg)
    after_agg_callable = computation_impl.ConcreteComputation.from_building_block(
        after_agg)

    # Both branches build the same pair of computations. They differ only in
    # whether the original top-level argument exists and must be threaded
    # through.
    if before_agg.type_signature.parameter is not None:
        # TODO(b/147499373): If None-arguments were uniformly represented as empty
        # tuples, we would be able to avoid this (and related) ugly casing.

        @federated_computation.federated_computation(
            before_agg.type_signature.parameter)
        def up_to_merge_computation(arg):
            federated_aggregate_args = before_agg_callable(
                arg)['federated_aggregate_param']
            # Element 0 is the value to aggregate. The zero is recomputed by
            # calling `zero_comp` instead of being read from these arguments.
            value_to_aggregate = federated_aggregate_args[0]
            zero = zero_comp()
            return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                                  accumulate_comp, merge_comp,
                                                  identity_report)

        # Takes the original argument plus the merged value placed at the
        # server.
        @federated_computation.federated_computation(
            before_agg.type_signature.parameter,
            computation_types.at_server(identity_report.type_signature.result))
        def after_merge_computation(top_level_arg, merge_result):
            # Apply the original report function that was deferred above.
            reported_result = intrinsics.federated_map(report_comp,
                                                       merge_result)
            return after_agg_callable(top_level_arg, [reported_result])

    else:

        # No original argument: `up_to_merge` takes none, and `after_merge`
        # takes only the merged value.
        @federated_computation.federated_computation()
        def up_to_merge_computation():
            federated_aggregate_args = before_agg_callable(
            )['federated_aggregate_param']
            value_to_aggregate = federated_aggregate_args[0]
            zero = zero_comp()
            return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                                  accumulate_comp, merge_comp,
                                                  identity_report)

        @federated_computation.federated_computation(
            computation_types.at_server(identity_report.type_signature.result))
        def after_merge_computation(merge_result):
            reported_result = intrinsics.federated_map(report_comp,
                                                       merge_result)
            # NOTE(review): the result is nested one level deeper here than in
            # the branch above (`[[...]]` vs. `arg, [...]`). This presumably
            # mirrors the parameter structure `after_agg` gets when the original
            # computation has no argument; confirm.
            return after_agg_callable([[reported_result]])

    # Re-annotate `after_merge` so its result type is exactly `comp`'s original
    # result type. NOTE(review): presumably the traced result type can differ
    # from the original annotation; confirm.
    annotated_type_signature = computation_types.FunctionType(
        after_merge_computation.type_signature.parameter, original_return_type)
    after_merge_computation = computation_impl.ConcreteComputation.with_type(
        after_merge_computation, annotated_type_signature)

    return mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge_computation,
        merge=merge_comp,
        after_merge=after_merge_computation)