def test_raises_on_different_parameter_types(self):
  # Two identity lambdas that both bind the name `x`, but at different dtypes;
  # concatenating them is expected to fail with a TypeError.
  int_identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  float_identity = building_blocks.Lambda(
      'x', tf.float32, building_blocks.Reference('x', tf.float32))
  with self.assertRaises(TypeError):
    transformations.concatenate_function_outputs(int_identity, float_identity)
 def test_raises_on_non_lambda_args(self):
   reference = building_blocks.Reference('x', tf.int32)
   tff_lambda = building_blocks.Lambda('x', tf.int32, reference)
   with self.assertRaises(TypeError):
     transformations.concatenate_function_outputs(tff_lambda, reference)
   with self.assertRaises(TypeError):
     transformations.concatenate_function_outputs(reference, tff_lambda)
 def test_raises_on_non_unique_names(self):
   reference = building_blocks.Reference('x', tf.int32)
   good_lambda = building_blocks.Lambda('x', tf.int32, reference)
   bad_lambda = building_blocks.Lambda('x', tf.int32, good_lambda)
   with self.assertRaises(ValueError):
     transformations.concatenate_function_outputs(good_lambda, bad_lambda)
   with self.assertRaises(ValueError):
     transformations.concatenate_function_outputs(bad_lambda, good_lambda)
# Example #4
# 0
def extract_work(before_aggregate, after_aggregate, canonical_form_types):
    """Converts `before_aggregate` and `after_aggregate` to `work`.

    Builds a lambda over `c3` whose result is the zipped `c4` value, assembled
    from output 0 of `before_aggregate` (`c5`) and output 2 of
    `after_aggregate` (`c6`). It then passes that lambda to
    `transformations.consolidate_and_extract_local_processing`, validates the
    result, and returns it.

    Args:
      before_aggregate: The first result of splitting `after_broadcast` on
        `tff_framework.FEDERATED_AGGREGATE`.
      after_aggregate: The second result of splitting `after_broadcast` on
        `tff_framework.FEDERATED_AGGREGATE`.
      canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
        type signatures specified by the `tff.utils.IterativeProcess` we are
        compiling. Only the `'work_type'` entry is read here.

    Returns:
      `work` as specified by `canonical_form.CanonicalForm`, an instance of
      `tff_framework.CompiledComputation` whose `type_signature` equals
      `canonical_form_types['work_type']`.

    Raises:
      transformations.CanonicalFormCompilationError: If the extracted value is
        not a `tff_framework.CompiledComputation`, or if its type signature
        differs from `canonical_form_types['work_type']`.
    """
    # See `get_iterative_process_for_canonical_form()` above for the meaning of
    # variable names used in the code below.
    # NOTE(review): keep the helper calls below in their current order; they
    # may generate fresh names internally, so reordering could change the
    # resulting AST -- confirm before restructuring.

    # Index paths that (per the variable name) locate c3's elements inside
    # `before_aggregate`'s parameter; `.result.function` extracts the lambda
    # from the building block the helper returns.
    c3_elements_in_before_aggregate_parameter = [[0, 1], [1]]
    c3_to_before_aggregate_computation = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            before_aggregate,
            c3_elements_in_before_aggregate_parameter).result.function)
    # Keep only output 0 of that lambda: c5.
    c5_index_in_before_aggregate_result = 0
    c3_to_c5_computation = transformations.select_output_from_lambda(
        c3_to_before_aggregate_computation,
        c5_index_in_before_aggregate_result)
    # Keep only output 2 of `after_aggregate`: c6.
    c6_index_in_after_aggregate_result = 2
    after_aggregate_to_c6_computation = transformations.select_output_from_lambda(
        after_aggregate, c6_index_in_after_aggregate_result)
    # Re-express the c6 selection as a function of c3, using index paths into
    # `after_aggregate`'s parameter.
    c3_elements_in_after_aggregate_parameter = [[0, 0, 1], [0, 1]]
    c3_to_c6_computation = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            after_aggregate_to_c6_computation,
            c3_elements_in_after_aggregate_parameter).result.function)
    # One lambda over c3 whose result carries both c5 and c6 ("unzipped" c4).
    c3_to_unzipped_c4_computation = transformations.concatenate_function_outputs(
        c3_to_c5_computation, c3_to_c6_computation)
    # Same parameter; the result is wrapped with `create_federated_zip` to
    # form c4.
    c3_to_c4_computation = tff_framework.Lambda(
        c3_to_unzipped_c4_computation.parameter_name,
        c3_to_unzipped_c4_computation.parameter_type,
        tff_framework.create_federated_zip(
            c3_to_unzipped_c4_computation.result))

    work = transformations.consolidate_and_extract_local_processing(
        c3_to_c4_computation)
    # NOTE(review): this error message reads `work.type_signature`, which
    # assumes any returned value exposes that attribute.
    if not isinstance(work, tff_framework.CompiledComputation):
        raise transformations.CanonicalFormCompilationError(
            'Failed to extract a `tff_framework.CompiledComputation` from '
            'work, instead received a {} (of type {}).'.format(
                type(work), work.type_signature))
    # The extracted block must match the type the canonical form expects.
    if work.type_signature != canonical_form_types['work_type']:
        raise transformations.CanonicalFormCompilationError(
            'Extracted a TF block of the wrong type. Expected a function with type '
            '{}, but the type signature of the TF block was {}'.format(
                canonical_form_types['work_type'], work.type_signature))
    return work
# Example #5
# 0
 def test_concatenates_identities(self):
   identity_x = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Reference('x', tf.int32))
   identity_y = building_blocks.Lambda(
       'y', tf.int32, building_blocks.Reference('y', tf.int32))
   result = mapreduce_transformations.concatenate_function_outputs(
       identity_x, identity_y)
   # Both identities end up reading one shared parameter, so each component
   # of the resulting tuple is a reference to that same parameter.
   self.assertEqual(str(result), '(_var1 -> <_var1,_var1>)')
# Example #6
# 0
def _extract_work(before_aggregate, after_aggregate):
    """Extracts `work` from `before_aggregate` and `after_aggregate`.

    This function is intended to be used by
    `get_canonical_form_for_iterative_process` only. As a result, this function
    does not assert that `before_aggregate` or `after_aggregate` has the
    expected structure; the caller is expected to perform these checks before
    calling this function.

    Builds a lambda over `c3` whose result is a zipped `c4` value, made from the
    `c6` outputs of `before_aggregate` and the `c8` output (index 2) of
    `after_aggregate`. It then returns whatever
    `transformations.consolidate_and_extract_local_processing` produces for it.

    Args:
      before_aggregate: The first result of splitting `after_broadcast` on
        aggregate intrinsics.
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.

    Returns:
      `work` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: Documented as raised if we
        extract an AST of the wrong type. NOTE(review): this function performs
        no such check itself, so any such error would have to come from the
        `transformations` helpers it calls -- confirm.
    """
    # NOTE(review): the cN variable names follow a naming convention documented
    # elsewhere in this module (presumably alongside the canonical-form
    # construction). Keep the helper calls in their current order; they may
    # generate fresh names internally.

    # Index paths that (per the variable name) locate c3's elements inside
    # `before_aggregate`'s parameter; `.result.function` extracts the lambda.
    c3_elements_in_before_aggregate_parameter = [[0, 1], [1]]
    c3_to_before_aggregate_computation = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            before_aggregate,
            c3_elements_in_before_aggregate_parameter).result.function)
    # Here a list of index paths, (0, 0) and (1, 0), is selected rather than a
    # single index; together they form c6 per the variable name.
    c6_index_in_before_aggregate_result = [[0, 0], [1, 0]]
    c3_to_c6_computation = transformations.select_output_from_lambda(
        c3_to_before_aggregate_computation,
        c6_index_in_before_aggregate_result)
    # Keep only output 2 of `after_aggregate`: c8.
    c8_index_in_after_aggregate_result = 2
    after_aggregate_to_c8_computation = transformations.select_output_from_lambda(
        after_aggregate, c8_index_in_after_aggregate_result)
    # Re-express the c8 selection as a function of c3.
    c3_elements_in_after_aggregate_parameter = [[0, 0, 1], [0, 1]]
    c3_to_c8_computation = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            after_aggregate_to_c8_computation,
            c3_elements_in_after_aggregate_parameter).result.function)
    # One lambda over c3 whose result carries both c6 and c8 ("unzipped" c4).
    c3_to_unzipped_c4_computation = transformations.concatenate_function_outputs(
        c3_to_c6_computation, c3_to_c8_computation)
    # Same parameter; the result is wrapped with `create_federated_zip` to
    # form c4.
    c3_to_c4_computation = building_blocks.Lambda(
        c3_to_unzipped_c4_computation.parameter_name,
        c3_to_unzipped_c4_computation.parameter_type,
        building_block_factory.create_federated_zip(
            c3_to_unzipped_c4_computation.result))

    return transformations.consolidate_and_extract_local_processing(
        c3_to_c4_computation)
# Example #7
# 0
def _extract_work(before_aggregate, after_aggregate):
  """Converts `before_aggregate` and `after_aggregate` to `work`.

  Builds a lambda over `c3` whose result is the zipped `c4` value, made from
  output 0 of `before_aggregate` (`c5`) and output 2 of `after_aggregate`
  (`c6`). It then returns whatever
  `transformations.consolidate_and_extract_local_processing` produces for it.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.
    after_aggregate: The second result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: Documented as raised if we
      extract an AST of the wrong type. NOTE(review): this function performs no
      such check itself, so any such error would have to come from the
      `transformations` helpers it calls -- confirm.
  """
  # See `get_iterative_process_for_canonical_form()` above for the meaning of
  # variable names used in the code below.
  # NOTE(review): keep the helper calls in their current order; they may
  # generate fresh names internally.

  # Index paths that (per the variable name) locate c3's elements inside
  # `before_aggregate`'s parameter; `.result.function` extracts the lambda.
  c3_elements_in_before_aggregate_parameter = [[0, 1], [1]]
  c3_to_before_aggregate_computation = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          before_aggregate,
          c3_elements_in_before_aggregate_parameter).result.function)
  # Keep only output 0 of that lambda: c5.
  c5_index_in_before_aggregate_result = 0
  c3_to_c5_computation = transformations.select_output_from_lambda(
      c3_to_before_aggregate_computation, c5_index_in_before_aggregate_result)
  # Keep only output 2 of `after_aggregate`: c6.
  c6_index_in_after_aggregate_result = 2
  after_aggregate_to_c6_computation = transformations.select_output_from_lambda(
      after_aggregate, c6_index_in_after_aggregate_result)
  # Re-express the c6 selection as a function of c3.
  c3_elements_in_after_aggregate_parameter = [[0, 0, 1], [0, 1]]
  c3_to_c6_computation = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          after_aggregate_to_c6_computation,
          c3_elements_in_after_aggregate_parameter).result.function)
  # One lambda over c3 whose result carries both c5 and c6 ("unzipped" c4).
  c3_to_unzipped_c4_computation = transformations.concatenate_function_outputs(
      c3_to_c5_computation, c3_to_c6_computation)
  # Same parameter; the result is wrapped with `create_federated_zip` to form
  # c4.
  c3_to_c4_computation = building_blocks.Lambda(
      c3_to_unzipped_c4_computation.parameter_name,
      c3_to_unzipped_c4_computation.parameter_type,
      building_block_factory.create_federated_zip(
          c3_to_unzipped_c4_computation.result))

  work = transformations.consolidate_and_extract_local_processing(
      c3_to_c4_computation)
  return work
  def test_parameters_are_mapped_together(self):
    """Checks every reference in the concatenated lambda names its parameter.

    Concatenates two identity lambdas with distinct parameter names (`x`, `y`).
    The test then walks the result and fails if any Reference names something
    other than the concatenated lambda's own parameter.
    """
    x_reference = building_blocks.Reference('x', tf.int32)
    x_lambda = building_blocks.Lambda('x', tf.int32, x_reference)
    y_reference = building_blocks.Reference('y', tf.int32)
    y_lambda = building_blocks.Lambda('y', tf.int32, y_reference)
    concatenated = transformations.concatenate_function_outputs(
        x_lambda, y_lambda)
    parameter_name = concatenated.parameter_name

    def _raise_on_other_name_reference(comp):
      # A reference to any other name means an input lambda's parameter was
      # not rebound to the shared parameter. The message names the offending
      # reference so a failure is diagnosable.
      if (isinstance(comp, building_blocks.Reference) and
          comp.name != parameter_name):
        raise ValueError(
            'Found a reference to {!r}; expected only references to the '
            'concatenated parameter {!r}.'.format(comp.name, parameter_name))
      return comp, True

    # Presumably raises if any name is bound more than once -- confirm against
    # `tree_analysis`.
    tree_analysis.check_has_unique_names(concatenated)
    transformation_utils.transform_postorder(concatenated,
                                             _raise_on_other_name_reference)