Esempio n. 1
0
 def test_raises_on_non_federated_selection(self):
     """Expects `TypeError` when zipping index 0 of a `[tf.int32]` parameter."""
     identity_body = building_blocks.Reference('x', [tf.int32])
     non_federated_lambda = building_blocks.Lambda(
         'x', [tf.int32], identity_body)
     selection = [[0]]
     with self.assertRaises(TypeError):
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
             non_federated_lambda, selection)
Esempio n. 2
0
 def test_raises_on_non_tuple_parameter(self):
     """Expects `TypeError` when the lambda parameter is a bare `tf.int32`."""
     scalar_body = building_blocks.Reference('x', tf.int32)
     scalar_lambda = building_blocks.Lambda('x', tf.int32, scalar_body)
     selection = [[0]]
     with self.assertRaises(TypeError):
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
             scalar_lambda, selection)
Esempio n. 3
0
 def test_raises_on_selection_from_non_tuple(self):
     """Expects a 'nonexistent index' `TypeError` for the path `[0, 0]`.

     The parameter type is `[tf.int32]`, so the path selects index 0 and then
     index 0 again from that first element.
     """
     identity_body = building_blocks.Reference('x', [tf.int32])
     flat_lambda = building_blocks.Lambda('x', [tf.int32], identity_body)
     nested_selection = [[0, 0]]
     with self.assertRaisesRegex(TypeError, 'nonexistent index'):
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
             flat_lambda, nested_selection)
Esempio n. 4
0
 def test_raises_on_non_lambda_comp(self):
     """Expects `TypeError` when a `Reference` is passed instead of a `Lambda`."""
     not_a_lambda = building_blocks.Reference('x', [tf.int32])
     selection = [0]
     with self.assertRaises(TypeError):
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
             not_a_lambda, selection)
Esempio n. 5
0
 def test_raises_on_none(self):
     """Expects `TypeError` when the computation argument is `None`."""
     missing_comp = None
     selection = [0]
     with self.assertRaises(TypeError):
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
             missing_comp, selection)
Esempio n. 6
0
def _extract_update(after_aggregate):
    """Builds the `update` computation out of `after_aggregate`.

    Only `get_canonical_form_for_iterative_process` is expected to call this
    helper. No structural validation of `after_aggregate` happens here; the
    caller must have performed those checks already.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.

    Returns:
      `update` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we extract an AST of
        the wrong type.
    """
    # Keep only outputs 0 and 1 of `after_aggregate`, then zip that result.
    selected_output = transformations.select_output_from_lambda(
        after_aggregate, [0, 1])
    zipped_result = building_block_factory.create_federated_zip(
        selected_output.result)
    zipped_output_fn = building_blocks.Lambda(
        selected_output.parameter_name,
        selected_output.parameter_type,
        zipped_result)

    # Selection paths into the parameter of `after_aggregate` that feed the
    # lower-level function extracted below.
    parameter_selections = [[0, 0, 0], [1, 0], [1, 1]]
    rebound = transformations.zip_selection_as_argument_to_lower_level_lambda(
        zipped_output_fn, parameter_selections)
    update_fn = rebound.result.function

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support
    # selecting from nested structures, therefore we need to pack the type
    # signature `<s1, s3, s4>` as `<s1, <s3, s4>>`.
    fresh_names = building_block_factory.unique_name_generator(update_fn)

    # The first generated name is used for the packed `<s1, <s3, s4>>` value.
    packed_name = next(fresh_names)
    flat_members = update_fn.parameter_type.member
    nested_pair_type = computation_types.NamedTupleType(
        [flat_members[1], flat_members[2]])
    packed_type = computation_types.NamedTupleType(
        [flat_members[0], nested_pair_type])
    packed_ref = building_blocks.Reference(packed_name, packed_type)

    # Flatten the packed value back into the three-element tuple that
    # `update_fn` takes as its argument.
    nested_pair = building_blocks.Selection(packed_ref, index=1)
    flattened = building_blocks.Tuple([
        building_blocks.Selection(packed_ref, index=0),
        building_blocks.Selection(nested_pair, index=0),
        building_blocks.Selection(nested_pair, index=1),
    ])
    flatten_fn = building_blocks.Lambda(
        packed_ref.name, packed_ref.type_signature, flattened)

    # The second generated name is used for the server-placed outer argument.
    server_arg = building_blocks.Reference(
        next(fresh_names),
        computation_types.FederatedType(packed_type, placements.SERVER))
    flat_arg = building_block_factory.create_federated_map_or_apply(
        flatten_fn, server_arg)
    wrapped = building_blocks.Lambda(
        server_arg.name,
        server_arg.type_signature,
        building_blocks.Call(update_fn, flat_arg))
    return transformations.consolidate_and_extract_local_processing(wrapped)
Esempio n. 7
0
 def test_raises_on_non_int_index(self):
   """Expects `TypeError` when a selection path holds the string `'a'`."""
   named_body = building_blocks.Reference('x', [('a', tf.int32)])
   lambda_under_test = building_blocks.Lambda('x', [tf.int32], named_body)
   string_selection = [['a']]
   with self.assertRaises(TypeError):
     transformations.zip_selection_as_argument_to_lower_level_lambda(
         lambda_under_test, string_selection)
Esempio n. 8
0
 def test_raises_on_selection_tuple(self):
   """Expects `TypeError` when the selection is a tuple rather than a list."""
   lam = building_blocks.Lambda('x', tf.int32,
                                building_blocks.Reference('x', [tf.int32]))
   # `(0)` is just the int 0 in parentheses; the trailing comma is what makes
   # this a one-element tuple, which is the case the test name describes.
   selection_as_tuple = (0,)
   with self.assertRaises(TypeError):
     transformations.zip_selection_as_argument_to_lower_level_lambda(
         lam, selection_as_tuple)