コード例 #1
0
 def test_passes_with_noarg_top_level_computation(self):
   """A no-argument `up_to_merge` yields a valid `MergeableCompForm`."""
   count_clients = build_noarg_count_clients_computation()
   whimsy_merge = build_whimsy_merge_computation(tf.int32)
   # The no-arg computation's parameter type (presumably `None`) is handed
   # straight to the after-merge builder.
   whimsy_after_merge = build_whimsy_after_merge_computation(
       count_clients.type_signature.parameter,
       whimsy_merge.type_signature.result)
   form = mergeable_comp_execution_context.MergeableCompForm(
       up_to_merge=count_clients,
       merge=whimsy_merge,
       after_merge=whimsy_after_merge)
   self.assertIsInstance(
       form, mergeable_comp_execution_context.MergeableCompForm)
コード例 #2
0
  def test_passes_with_correct_signatures(self):
    """Matching signatures across all three parts construct successfully."""
    server_int = computation_types.at_server(tf.int32)
    clients_int = computation_types.at_clients(tf.int32)
    sum_up_to_merge = build_sum_client_arg_computation(
        server_int, clients_int)
    whimsy_merge = build_whimsy_merge_computation(tf.int32)
    whimsy_after_merge = build_whimsy_after_merge_computation(
        sum_up_to_merge.type_signature.parameter,
        whimsy_merge.type_signature.result)

    form = mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=sum_up_to_merge,
        merge=whimsy_merge,
        after_merge=whimsy_after_merge)

    self.assertIsInstance(
        form, mergeable_comp_execution_context.MergeableCompForm)
コード例 #3
0
  def test_raises_no_top_level_argument_in_after_agg(self):
    """`after_merge` must also accept the original top-level argument."""
    server_int = computation_types.at_server(tf.int32)
    clients_int = computation_types.at_clients(tf.int32)
    up_to_merge = build_sum_client_arg_computation(server_int, clients_int)
    merge = build_whimsy_merge_computation(tf.int32)
    merged_at_server = computation_types.at_server(
        merge.type_signature.result)

    # Accepts only the merge result; the original argument is omitted.
    @federated_computation.federated_computation(merged_at_server)
    def bad_after_merge(x):
      return x

    with self.assertRaises(computation_types.TypeNotAssignableError):
      mergeable_comp_execution_context.MergeableCompForm(
          up_to_merge=up_to_merge, merge=merge, after_merge=bad_after_merge)
  def test_raises_up_to_merge_returns_non_server_placed_result(self):
    """`up_to_merge` must return a single SERVER-placed value."""

    # Same decorator module as the other tests in this suite.
    @federated_computation.federated_computation(
        computation_types.at_server(tf.int32))
    def bad_up_to_merge(x):
      # Returns an unplaced tuple of two SERVER-placed values rather than a
      # single SERVER-placed result.
      return x, x

    merge = build_whimsy_merge_computation(tf.int32)

    after_merge = build_whimsy_after_merge_computation(
        bad_up_to_merge.type_signature.parameter, merge.type_signature.result)

    with self.assertRaises(mergeable_comp_execution_context.UpToMergeTypeError):
      mergeable_comp_execution_context.MergeableCompForm(
          up_to_merge=bad_up_to_merge, merge=merge, after_merge=after_merge)
コード例 #5
0
  def test_raises_mismatched_up_to_merge_and_merge(self):
    """A float32 `merge` cannot consume the int32 `up_to_merge` output."""
    up_to_merge = build_sum_client_arg_computation(
        computation_types.at_server(tf.int32),
        computation_types.at_clients(tf.int32))
    float_merge = build_whimsy_merge_computation(tf.float32)
    original_arg_type = up_to_merge.type_signature.parameter
    merged_type = computation_types.at_server(
        float_merge.type_signature.result)

    @federated_computation.federated_computation(
        original_arg_type, merged_type)
    def after_merge(x, y):
      return x, y

    with self.assertRaises(
        mergeable_comp_execution_context.MergeTypeNotAssignableError):
      mergeable_comp_execution_context.MergeableCompForm(
          up_to_merge=up_to_merge, merge=float_merge, after_merge=after_merge)
コード例 #6
0
  def test_computes_sum_of_all_values(self, arg, expected_sum):
    """Runs the sum pipeline on five local executors and checks the total."""
    up_to_merge = build_sum_client_arg_computation(
        computation_types.at_server(tf.int32),
        computation_types.at_clients(tf.int32))
    merge = build_sum_merge_computation(tf.int32)
    after_merge = build_sum_merge_with_first_arg_computation(
        up_to_merge.type_signature.parameter, merge.type_signature.result)

    form = mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)
    factories = [
        python_executor_stacks.local_executor_factory() for _ in range(5)
    ]
    context = mergeable_comp_execution_context.MergeableCompExecutionContext(
        factories)

    want = type_conversions.type_to_py_container(
        expected_sum, after_merge.type_signature.result)
    got = context.invoke(form, arg)
    self.assertEqual(want, got)
コード例 #7
0
  def test_raises_merge_computation_not_assignable_result(self):
    """A `merge` returning float for int32 inputs is rejected."""
    up_to_merge = build_sum_client_arg_computation(
        computation_types.at_server(tf.int32),
        computation_types.at_clients(tf.int32))

    @tensorflow_computation.tf_computation(tf.int32, tf.int32)
    def bad_merge(x, y):
      del x, y  # Unused
      return 1.0  # A float, mismatching the int32 inputs.

    server_merge_type = computation_types.at_server(
        bad_merge.type_signature.result)

    @federated_computation.federated_computation(
        up_to_merge.type_signature.parameter, server_merge_type)
    def after_merge(x, y):
      return x, y

    with self.assertRaises(
        mergeable_comp_execution_context.MergeTypeNotAssignableError):
      mergeable_comp_execution_context.MergeableCompForm(
          up_to_merge=up_to_merge, merge=bad_merge, after_merge=after_merge)
コード例 #8
0
  def test_counts_clients_with_noarg_computation(self):
    """Counts clients split evenly across several executors."""
    num_clients = 100
    num_executors = 5
    # Integer division: `default_num_clients` is a count, so avoid the
    # float round-trip of `int(a / b)`. The expected total is derived from
    # the per-executor count so the assertion stays correct even when
    # `num_clients` is not a multiple of `num_executors`.
    clients_per_executor = num_clients // num_executors
    up_to_merge = build_noarg_count_clients_computation()
    merge = build_sum_merge_computation(tf.int32)
    after_merge = build_return_merge_result_with_no_first_arg_computation(
        merge.type_signature.result)

    mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)
    ex_factories = [
        python_executor_stacks.local_executor_factory(
            default_num_clients=clients_per_executor)
        for _ in range(num_executors)
    ]
    mergeable_comp_context = mergeable_comp_execution_context.MergeableCompExecutionContext(
        ex_factories)

    expected_result = clients_per_executor * num_executors
    # The computation takes no argument, so `None` is passed to `invoke`.
    result = mergeable_comp_context.invoke(mergeable_comp_form, None)
    self.assertEqual(result, expected_result)
コード例 #9
0
  def test_raises_with_aggregation_in_after_agg(self):
    """`after_merge` may not itself contain an aggregation intrinsic."""
    up_to_merge = build_sum_client_arg_computation(
        computation_types.at_server(tf.int32),
        computation_types.at_clients(tf.int32))
    merge = build_whimsy_merge_computation(tf.int32)
    original_arg_type = up_to_merge.type_signature.parameter
    merged_at_server = computation_types.at_server(
        merge.type_signature.result)

    @federated_computation.federated_computation(original_arg_type,
                                                 merged_at_server)
    def after_merge_with_sum(original_arg, merged_arg):
      del merged_arg  # Unused
      # Index 1 of the original argument holds the CLIENTS-placed value.
      clients_value = original_arg[1]
      return intrinsics.federated_sum(clients_value)

    with self.assertRaisesRegex(
        mergeable_comp_execution_context.AfterMergeStructureError,
        'federated_sum'):
      mergeable_comp_execution_context.MergeableCompForm(
          up_to_merge=up_to_merge,
          merge=merge,
          after_merge=after_merge_with_sum)
コード例 #10
0
  def test_runs_computation_with_clients_placed_return_values(self, arg):
    """Round-trips an argument with a CLIENTS-placed part through the context."""
    up_to_merge = build_sum_client_arg_computation(
        computation_types.at_server(tf.int32),
        computation_types.at_clients(tf.int32))
    merge = build_whimsy_merge_computation(tf.int32)
    # Simply returns the original argument, so the expected result is `arg`.
    after_merge = build_whimsy_after_merge_computation(
        up_to_merge.type_signature.parameter, merge.type_signature.result)

    mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)
    ex_factories = [
        python_executor_stacks.local_executor_factory() for _ in range(5)
    ]
    mergeable_comp_context = mergeable_comp_execution_context.MergeableCompExecutionContext(
        ex_factories)
    # We preemptively package as a struct to work around short-circuiting in
    # type_to_py_container in a non-Struct argument case.
    arg = structure.Struct.unnamed(*arg)
    expected_result = type_conversions.type_to_py_container(
        arg, after_merge.type_signature.result)
    result = mergeable_comp_context.invoke(mergeable_comp_form, arg)

    self.assertEqual(result, expected_result)
コード例 #11
0
def compile_to_mergeable_comp_form(
    comp: computation_impl.ConcreteComputation
) -> mergeable_comp_execution_context.MergeableCompForm:
    """Compiles a computation with a single aggregation to `MergeableCompForm`.

    Compilation proceeds by splitting on the lone aggregation, and using the
    aggregation's internal functions to generate a semantically equivalent
    instance of `mergeable_comp_execution_context.MergeableCompForm`.

    The aggregation's `report` function is deferred. `up_to_merge` runs the
    aggregation with an identity report, so it returns the merged, pre-report
    value. `after_merge` applies the original `report` to that merge result
    and then runs the remainder of the computation.

    Args:
      comp: Instance of `computation_impl.ConcreteComputation` to compile.
        Assumed to be representable as a computation with a single aggregation
        in its body, so that for example two parallel aggregations are allowed,
        but multiple dependent aggregations are disallowed. Additionally
        assumed to be of functional type.

    Returns:
      A semantically equivalent instance of
      `mergeable_comp_execution_context.MergeableCompForm`, whose
      `after_merge` is annotated with the original result type of `comp`.

    Raises:
      TypeError: If `comp`, once converted to a building block, is not of
        functional TFF type. NOTE(review): presumably raised by
        `_ensure_lambda`; confirm.
      ValueError: If `comp` cannot be represented as a computation with at
        most one aggregation in its body. NOTE(review): presumably raised by
        `tree_analysis.check_aggregate_not_dependent_on_aggregate`; confirm.
    """
    # Saved so the final `after_merge` can be re-annotated with the exact
    # return type of the input computation (see `with_type` at the end).
    original_return_type = comp.type_signature.result
    building_block = comp.to_building_block()
    lam = _ensure_lambda(building_block)
    # Replace intrinsics with their bodies (presumably so that only the
    # primitive aggregation remains to split on). The second return value is
    # discarded.
    lowered_bb, _ = tree_transformations.replace_intrinsics_with_bodies(lam)

    # We transform the body of this computation to easily preserve the top-level
    # lambda required by force-aligning.
    call_dominant_body_bb = transformations.to_call_dominant(lowered_bb.result)
    call_dominant_bb = building_blocks.Lambda(lowered_bb.parameter_name,
                                              lowered_bb.parameter_type,
                                              call_dominant_body_bb)

    # This check should not throw false positives because we just ensured we are
    # in call-dominant form.
    tree_analysis.check_aggregate_not_dependent_on_aggregate(call_dominant_bb)

    # Split around `federated_aggregate`. `before_agg` produces the
    # aggregate's arguments under the 'federated_aggregate_param' key;
    # `after_agg` consumes the aggregate's (reported) result.
    before_agg, after_agg = transformations.force_align_and_split_by_intrinsics(
        call_dominant_bb,
        [building_block_factory.create_null_federated_aggregate()])

    # Construct a report function which accepts the result of merge.
    # The aggregate's parameters are ordered (value, zero, accumulate, merge,
    # report), so element 3 is the merge function. An identity report over its
    # result type makes `up_to_merge` return the raw merged value; the real
    # `report_comp` is applied later in `after_merge`.
    merge_fn_type = before_agg.type_signature.result[
        'federated_aggregate_param'][3]
    identity_report = computation_impl.ConcreteComputation.from_building_block(
        building_block_factory.create_compiled_identity(merge_fn_type.result))

    zero_comp, accumulate_comp, merge_comp, report_comp = _extract_federated_aggregate_computations(
        before_agg)

    before_agg_callable = computation_impl.ConcreteComputation.from_building_block(
        before_agg)
    after_agg_callable = computation_impl.ConcreteComputation.from_building_block(
        after_agg)

    if before_agg.type_signature.parameter is not None:
        # TODO(b/147499373): If None-arguments were uniformly represented as empty
        # tuples, we would be able to avoid this (and related) ugly casing.

        # Runs the pre-aggregation logic and the aggregation itself, stopping
        # at the (identity-reported) merged value.
        @federated_computation.federated_computation(
            before_agg.type_signature.parameter)
        def up_to_merge_computation(arg):
            federated_aggregate_args = before_agg_callable(
                arg)['federated_aggregate_param']
            value_to_aggregate = federated_aggregate_args[0]
            zero = zero_comp()
            return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                                  accumulate_comp, merge_comp,
                                                  identity_report)

        # Applies the deferred report to the merged value, then runs the
        # post-aggregation logic with access to the original argument.
        @federated_computation.federated_computation(
            before_agg.type_signature.parameter,
            computation_types.at_server(identity_report.type_signature.result))
        def after_merge_computation(top_level_arg, merge_result):
            reported_result = intrinsics.federated_map(report_comp,
                                                       merge_result)
            return after_agg_callable(top_level_arg, [reported_result])

    else:

        # No-argument variant of `up_to_merge_computation` above.
        @federated_computation.federated_computation()
        def up_to_merge_computation():
            federated_aggregate_args = before_agg_callable(
            )['federated_aggregate_param']
            value_to_aggregate = federated_aggregate_args[0]
            zero = zero_comp()
            return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                                  accumulate_comp, merge_comp,
                                                  identity_report)

        # No-argument variant: `after_merge` receives only the merge result.
        # NOTE(review): the doubly-nested list presumably matches the
        # parameter structure `after_agg` expects when there is no top-level
        # argument; confirm.
        @federated_computation.federated_computation(
            computation_types.at_server(identity_report.type_signature.result))
        def after_merge_computation(merge_result):
            reported_result = intrinsics.federated_map(report_comp,
                                                       merge_result)
            return after_agg_callable([[reported_result]])

    # Re-annotate `after_merge` so its declared result type is exactly that of
    # the input computation. NOTE(review): presumably the traced result type
    # can differ, e.g. in container kind; confirm.
    annotated_type_signature = computation_types.FunctionType(
        after_merge_computation.type_signature.parameter, original_return_type)
    after_merge_computation = computation_impl.ConcreteComputation.with_type(
        after_merge_computation, annotated_type_signature)

    return mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge_computation,
        merge=merge_comp,
        after_merge=after_merge_computation)