Exemplo n.º 1
0
 def test_binds_single_argument_to_lower_lambda(self):
     # Input: `x -> x[0]`, where `x` is a pair of federated int32 values, the
     # first placed at CLIENTS and the second at SERVER.
     int_at_clients = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
     int_at_server = computation_types.FederatedType(tf.int32,
                                                     placements.SERVER)
     pair_type = computation_types.NamedTupleType(
         [int_at_clients, int_at_server])
     x_ref = building_blocks.Reference('x', pair_type)
     select_first = building_blocks.Selection(x_ref, index=0)
     lam = building_blocks.Lambda('x', pair_type, select_first)

     bind = (mapreduce_transformations
             .bind_single_selection_as_argument_to_lower_level_lambda)
     rebound = bind(lam, 0)

     # Rebinding must leave the overall function type unchanged.
     self.assertEqual(rebound.type_signature, lam.type_signature)
     self.assertIsInstance(rebound, building_blocks.Lambda)
     # The rebound lambda's body is a call of an inner lambda.
     outer_call = rebound.result
     self.assertIsInstance(outer_call, building_blocks.Call)
     inner_fn = outer_call.function
     self.assertIsInstance(inner_fn, building_blocks.Lambda)
     # The inner lambda prints as `(P1 -> P1)`, with the same 4-character
     # prefix P on both sides of the arrow.
     self.assertRegex(str(inner_fn), r'\((.{4})1 -> (\1)1\)')
     # The call's argument is element 0 of the outer parameter.
     self.assertEqual(str(outer_call.argument), '_var1[0]')
def _extract_prepare(before_broadcast):
  """Extracts `prepare` from `before_broadcast`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` only. As a result, this function
  does not assert that `before_broadcast` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  Args:
    before_broadcast: The first result of splitting `next_comp` on
      `intrinsic_defs.FEDERATED_BROADCAST`.

  Returns:
    `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted AST is not
      a `building_blocks.CompiledComputation`.
  """
  s1_index_in_before_broadcast = 0
  # Bind element 0 of `before_broadcast`'s parameter as the argument of a
  # lower-level lambda. The rewritten lambda's result is a call, and its
  # `function` is that lower-level lambda, which is what gets compiled.
  s1_to_s2_computation = (
      transformations.bind_single_selection_as_argument_to_lower_level_lambda(
          before_broadcast, s1_index_in_before_broadcast)).result.function
  prepare = transformations.consolidate_and_extract_local_processing(
      s1_to_s2_computation)
  # Enforce the documented return type here instead of relying on the callee,
  # so that a malformed result surfaces as the documented error.
  # `getattr` keeps the error path from masking itself with an AttributeError
  # when the result is not a building block at all.
  if not isinstance(prepare, building_blocks.CompiledComputation):
    raise transformations.CanonicalFormCompilationError(
        'Failed to extract a `building_blocks.CompiledComputation` from '
        'prepare, instead received a {} (of type {}).'.format(
            type(prepare), getattr(prepare, 'type_signature', None)))
  return prepare
Exemplo n.º 3
0
def _extract_prepare(before_broadcast):
  """Converts `before_broadcast` into `prepare`.

  Args:
    before_broadcast: The first result of splitting `next_comp` on
      `intrinsic_defs.FEDERATED_BROADCAST`.

  Returns:
    `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted AST is not
      a `building_blocks.CompiledComputation`.
  """
  # See `get_iterative_process_for_canonical_form()` above for the meaning of
  # variable names used in the code below.
  s1_index_in_before_broadcast = 0
  # Bind element 0 of `before_broadcast`'s parameter as the argument of a
  # lower-level lambda, then take that lower-level lambda (the `function` of
  # the call in the rewritten lambda's result).
  s1_to_s2_computation = (
      transformations.bind_single_selection_as_argument_to_lower_level_lambda(
          before_broadcast, s1_index_in_before_broadcast)).result.function
  prepare = transformations.consolidate_and_extract_local_processing(
      s1_to_s2_computation)
  # Enforce the documented return type rather than trusting the callee.
  # `getattr` keeps the error path from masking itself with an AttributeError
  # when the result is not a building block at all.
  if not isinstance(prepare, building_blocks.CompiledComputation):
    raise transformations.CanonicalFormCompilationError(
        'Failed to extract a `building_blocks.CompiledComputation` from '
        'prepare, instead received a {} (of type {}).'.format(
            type(prepare), getattr(prepare, 'type_signature', None)))
  return prepare
Exemplo n.º 4
0
def extract_prepare(before_broadcast, canonical_form_types):
  """Builds the `prepare` computation out of `before_broadcast`.

  Args:
    before_broadcast: The first result of splitting `next_comp` on
      `intrinsic_defs.FEDERATED_BROADCAST`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm` type
      signatures specified by the `tff.utils.IterativeProcess` we are compiling.
      Its `'prepare_type'` entry is the type signature `prepare` must have.

  Returns:
    `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted computation
      is not a `building_blocks.CompiledComputation`, or if its type signature
      differs from `canonical_form_types['prepare_type']`.
  """
  # Names follow `get_iterative_process_for_canonical_form()` above; the
  # selected element of `before_broadcast` is s1, at position 0.
  s1_position = 0
  bind_selection = (
      transformations.bind_single_selection_as_argument_to_lower_level_lambda)
  rebound = bind_selection(before_broadcast, s1_position)
  # The rewritten lambda's result is a call; its `function` is the
  # lower-level lambda that gets compiled into `prepare`.
  s1_to_s2 = rebound.result.function
  prepare = transformations.consolidate_and_extract_local_processing(s1_to_s2)

  if not isinstance(prepare, building_blocks.CompiledComputation):
    raise transformations.CanonicalFormCompilationError(
        'Failed to extract a `building_blocks.CompiledComputation` from '
        'prepare, instead received a {} (of type {}).'.format(
            type(prepare), prepare.type_signature))

  # Looked up only after the isinstance check, as in the original ordering.
  expected_type = canonical_form_types['prepare_type']
  if prepare.type_signature != expected_type:
    raise transformations.CanonicalFormCompilationError(
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            expected_type, prepare.type_signature))
  return prepare