Esempio n. 1
0
def check_contains_no_new_unbound_references(old_tree, new_tree):
    """Raises if `new_tree` has top-level unbound references `old_tree` lacks.

    Args:
        old_tree: The reference computation building block; the names unbound
            at its top level are the permitted set.
        new_tree: The computation building block being checked.

    Raises:
        ValueError: If some name is unbound at the top level of `new_tree` but
            not at the top level of `old_tree`.
    """

    def _top_level_unbound(tree):
        # The helper maps every subtree to its unbound names; only the root is
        # relevant here.
        return transformation_utils.get_map_of_unbound_references(tree)[tree]

    allowed = _top_level_unbound(old_tree)
    introduced = _top_level_unbound(new_tree) - allowed
    if not introduced:
        return
    raise ValueError('Expected no new unbounded references. '
                     f'Old tree:\n{old_tree}\nNew tree:\n{new_tree}\n'
                     f'New unbound references: {introduced}')
Esempio n. 2
0
def _get_unbound_ref(block):
  """Returns a `Reference` for the single unbound name in `block`, if any.

  Args:
    block: The computation building block to inspect.

  Returns:
    `None` if nothing is unbound at the top level of `block`. Otherwise it
    returns a `building_blocks.Reference` named after the one unbound name.
    Its type is the `type_signature` of the last matching reference visited
    in a postorder walk of `block`, or `None` if no reference was matched.

  Raises:
    ValueError: If more than one name is unbound at the top level of `block`.
  """
  free_by_subtree = transformation_utils.get_map_of_unbound_references(block)
  free_names = free_by_subtree[block]
  count = len(free_names)
  if not count:
    return None
  if count > 1:
    raise ValueError('`create_tensorflow_representing_block` must be passed '
                     'a block with at most a single unbound reference; '
                     'encountered the block {} with {} unbound '
                     'references.'.format(block, count))

  (free_name,) = free_names
  found_type = None

  def _record_type(subcomp):
    # Remember the type of any reference to the free name; the walk is used
    # only for its side effect, so every node is returned unchanged.
    nonlocal found_type
    if subcomp.is_reference() and subcomp.name == free_name:
      found_type = subcomp.type_signature
    return subcomp, False

  transformation_utils.transform_postorder(block, _record_type)
  return building_blocks.Reference(free_name, found_type)
Esempio n. 3
0
def check_has_unique_names(comp):
  """Verifies that no variable in `comp` is bound more than once.

  The names unbound at the top level of `comp` seed the set of names already
  in use. Any binding inside `comp` that would mask one of those free names is
  therefore rejected as well.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock`.

  Raises:
    NonuniqueNameError: If a name is bound more than once, or a binding would
      shadow a name that is unbound at the top level of `comp`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  # Seeding with the free names of `comp` ensures it does not mask any names
  # from its parent scope.
  seen = transformation_utils.get_map_of_unbound_references(comp)[comp]

  def _claim(name):
    if name in seen:
      raise NonuniqueNameError(comp, name)
    seen.add(name)

  def _bound_names(node):
    # Names introduced by `node`: block locals, or a lambda's parameter when
    # the lambda actually has one.
    if node.is_block():
      return [name for name, _ in node.locals]
    if node.is_lambda() and node.parameter_type is not None:
      return [node.parameter_name]
    return []

  def _on_node(node):
    for name in _bound_names(node):
      _claim(name)

  visit_postorder(comp, _on_node)
 def transform(self, comp):
   """Removes block locals that the block's result does not depend on.

   Walks `comp.locals` from last to first. A local is kept only if its name is
   unbound in the result or in the value of a local already kept.

   Args:
     comp: The computation building block to transform.

   Returns:
     A `(comp, modified)` tuple, which is one of:
     - `comp` with `False` if it should not be transformed or every local is
       needed.
     - `comp.result` with `True` if no local is needed.
     - A new `building_blocks.Block` holding only the needed locals, with
       `True`.
   """
   if not self.should_transform(comp):
     return comp, False
   # Copy so that updating the live set never mutates the map's own value.
   unbound_ref_set = set(
       transformation_utils.get_map_of_unbound_references(comp.result)[
           comp.result])
   if (not unbound_ref_set) or (not comp.locals):
     return comp.result, True
   new_locals = []
   for name, val in reversed(comp.locals):
     if name in unbound_ref_set:
       new_locals.append((name, val))
       # Drop `name` *before* adding the value's own free names. `val` is
       # evaluated in the scope preceding this binding, so a reference to
       # `name` inside it (e.g. `x = f(x)`) targets an earlier local of the
       # same name, which must then be kept too.
       unbound_ref_set.discard(name)
       unbound_ref_set.update(
           transformation_utils.get_map_of_unbound_references(val)[val])
   if len(new_locals) == len(comp.locals):
     return comp, False
   elif not new_locals:
     return comp.result, True
   # Restore the original binding order; pass a concrete list, not an iterator.
   return building_blocks.Block(list(reversed(new_locals)), comp.result), True
 def __call__(self, proto: pb.Computation) -> Set[str]:
   """Returns the names of any unbound references in `proto`.

   Results are memoized in `self._evaluated_comps`, keyed by `_hash_proto` of
   the proto of every subcomputation seen. A later call on any such
   subcomputation is therefore answered from the cache.

   Args:
     proto: The `pb.Computation` to analyze.

   Returns:
     A fresh `set` of unbound reference names. It is a copy, so callers may
     mutate it without corrupting the cache.
   """
   py_typecheck.check_type(proto, pb.Computation)
   evaluated = self._evaluated_comps.get(_hash_proto(proto))
   if evaluated is not None:
     # The cached set is shared across calls; never hand it out directly.
     return set(evaluated)
   tree = building_blocks.ComputationBuildingBlock.from_proto(proto)
   unbound_ref_map = transformation_utils.get_map_of_unbound_references(tree)
   self._evaluated_comps.update(
       {_hash_proto(k.proto): v for k, v in unbound_ref_map.items()})
   # `unbound_ref_map[tree]` is now also stored in the cache, so copy it too.
   return set(unbound_ref_map[tree])
Esempio n. 6
0
def _inline_block_variables_required_to_align_intrinsics(comp, uri):
    """Inlines only the block variables needed to align intrinsics in `uri`.

    Many transformations insert block variables that do not affect alignment,
    and those should not be inlined, so only the required ones are touched.
    Inlining repeats for as long as the intrinsics cannot be extracted to the
    top-level lambda. As a result, unbound references inside a variable that
    was inlined are themselves inlined on a later pass.

    Args:
        comp: The `building_blocks.Lambda` to transform.
        uri: A Python `list` of intrinsic URIs (each a `str`).

    Returns:
        A new computation with the transformation applied, or the original
        `comp` if its intrinsics can already be extracted.

    Raises:
        tree_transformations.TransformationError: If the called intrinsics
            have no unbound references other than the lambda's parameter, or
            if inlining those names leaves the AST unmodified (which would
            otherwise loop forever).
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for intrinsic_uri in uri:
        py_typecheck.check_type(intrinsic_uri, str)

    while not _can_extract_intrinsics_to_top_level_lambda(comp, uri):
        ref_map = transformation_utils.get_map_of_unbound_references(comp)
        to_inline = set()
        for called in _get_called_intrinsics(comp, uri):
            free_names = ref_map[called]
            # The top-level lambda's parameter is not a block variable.
            free_names.discard(comp.parameter_name)
            to_inline.update(free_names)
        if not to_inline:
            raise tree_transformations.TransformationError(
                'Inlining `Block` variables has failed. Expected to find unbound '
                'references for called `Intrisic`s matching the URI: \'{}\', but '
                'none were found in the AST: \n{}'.format(
                    uri, comp.formatted_representation()))
        comp, changed = tree_transformations.inline_block_locals(
            comp, variable_names=to_inline)
        if not changed:
            raise tree_transformations.TransformationError(
                'Inlining `Block` variables has failed, this will result in an '
                'infinite loop. Expected to modify the AST by inlining the variable '
                'names: \'{}\', but no transformations to the AST: \n{}'.
                format(to_inline, comp.formatted_representation()))
        comp, _ = tree_transformations.uniquify_reference_names(comp)
    return comp
 def test_binding_single_arg_leaves_no_unbound_references(self):
   # Build a lambda over a pair of federated int32s (CLIENTS, SERVER) that
   # selects element 0. After zipping that selection into a lower-level
   # lambda's argument, the result must have no unbound references.
   arg_type = computation_types.NamedTupleType([
       computation_types.FederatedType(tf.int32, placements.CLIENTS),
       computation_types.FederatedType(tf.int32, placements.SERVER),
   ])
   selection = building_blocks.Selection(
       building_blocks.Reference('x', arg_type), index=0)
   lam = building_blocks.Lambda('x', arg_type, selection)
   extracted = (
       transformations.zip_selection_as_argument_to_lower_level_lambda(
           lam, [[0]]))
   free_names = transformation_utils.get_map_of_unbound_references(
       extracted)[extracted]
   self.assertEmpty(free_names)
 def test_single_element_selection_leaves_no_unbound_references(self):
   # Build a lambda over a struct of federated int32s (CLIENTS, SERVER) that
   # selects element 0. Rewriting it as a function of that federated
   # subparameter must leave no unbound references.
   arg_type = computation_types.StructType([
       computation_types.FederatedType(tf.int32, placements.CLIENTS),
       computation_types.FederatedType(tf.int32, placements.SERVER),
   ])
   selection = building_blocks.Selection(
       building_blocks.Reference('x', arg_type), index=0)
   lam = building_blocks.Lambda('x', arg_type, selection)
   rewritten = form_utils._as_function_of_some_federated_subparameters(
       lam, [(0,)])
   free_names = transformation_utils.get_map_of_unbound_references(
       rewritten)[rewritten]
   self.assertEmpty(free_names)
Esempio n. 9
0
 def should_transform(self, comp):
   """Returns whether `comp` is a candidate for TensorFlow generation.

   A candidate must meet all of these conditions:
   - Its type is TensorFlow-compatible, or it is a function whose parameter
     and result types are both TensorFlow-compatible.
   - It is not already a compiled computation or a call to one.
   - It has no unbound references at its top level.
   - It contains no `building_blocks.Intrinsic`.
   """
   signature = comp.type_signature
   tf_compatible = type_analysis.is_tensorflow_compatible_type
   if not tf_compatible(signature):
     tf_function = (
         signature.is_function() and tf_compatible(signature.parameter) and
         tf_compatible(signature.result))
     if not tf_function:
       return False
   already_compiled = comp.is_compiled_computation() or (
       comp.is_call() and comp.function.is_compiled_computation())
   if already_compiled:
     # These represent the final result of TF generation; no need to transform.
     return False
   if transformation_utils.get_map_of_unbound_references(comp)[comp]:
     # We cannot represent these captures without further information.
     return False
   return not tree_analysis.contains_types(comp, building_blocks.Intrinsic)
Esempio n. 10
0
def compile_local_subcomputations_to_tensorflow(
    comp: building_blocks.ComputationBuildingBlock,
) -> building_blocks.ComputationBuildingBlock:
    """Compiles subcomputations to TensorFlow where possible.

    A subcomputation is compiled when both conditions hold:
    - It is "local": nowhere in its subtree is there an intrinsic, data, a
      placement, a federated type, or an XLA compiled computation.
    - It has no unbound references.

    Args:
        comp: The computation to transform.

    Returns:
        The transformed computation.
    """
    comp = unpack_compiled_computations(comp)
    # Memoizes `_is_local` for every classified node, both `True` and `False`.
    # The preorder walk below queries each node, so without caching `True`
    # each local subtree would be re-walked once per ancestor.
    local_cache = {}

    def _is_local(subcomp):
        cached = local_cache.get(subcomp)
        if cached is not None:
            return cached
        if (subcomp.is_intrinsic() or subcomp.is_data() or
                subcomp.is_placement() or
                type_analysis.contains_federated_types(subcomp.type_signature)):
            result = False
        elif (subcomp.is_compiled_computation() and
              subcomp.proto.WhichOneof('computation') == 'xla'):
            result = False
        else:
            result = all(_is_local(child) for child in subcomp.children())
        local_cache[subcomp] = result
        return result

    unbound_ref_map = transformation_utils.get_map_of_unbound_references(comp)

    def _compile_if_local(subcomp):
        if _is_local(subcomp) and not unbound_ref_map[subcomp]:
            return compile_local_computation_to_tensorflow(subcomp), True
        return subcomp, False

    # Note: this transformation is preorder so that local subcomputations are not
    # first transformed to TensorFlow if they have a parent local computation
    # which could have instead been transformed into a larger single block of
    # TensorFlow.
    comp, _ = transformation_utils.transform_preorder(comp, _compile_if_local)
    return comp
Esempio n. 11
0
def contains_no_unbound_references(tree, excluding=None):
  """Returns `True` if every reference in `tree` is bound within `tree`.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree.
    excluding: Optional `str`, or collection of `str`s, naming references that
      are ignored by the test.

  Returns:
    `False` if any name outside `excluding` is unbound at the top level of
    `tree`, otherwise `True`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  if isinstance(excluding, str):
    excluding = [excluding]
  free_by_subtree = transformation_utils.get_map_of_unbound_references(tree)
  ignored = frozenset() if excluding is None else frozenset(excluding)
  return not (free_by_subtree[tree] - ignored)
Esempio n. 12
0
    def test_returns_tree(self):
        ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum(
        )
        next_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.next._computation_proto)
        next_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            next_tree)

        before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
            next_tree)

        before_federated_aggregate, after_federated_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))
        self.assertIsInstance(before_aggregate, building_blocks.Lambda)
        self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
        self.assertLen(before_aggregate.result, 2)

        # trees_equal will fail if computations refer to unbound references, so we
        # create a new dummy computation to bind them.
        unbound_refs_in_before_agg_result = transformation_utils.get_map_of_unbound_references(
            before_aggregate.result[0])[before_aggregate.result[0]]
        unbound_refs_in_before_fed_agg_result = transformation_utils.get_map_of_unbound_references(
            before_federated_aggregate.result)[
                before_federated_aggregate.result]

        dummy_data = building_blocks.Data('data',
                                          computation_types.AbstractType('T'))

        blk_binding_refs_in_before_agg = building_blocks.Block(
            [(name, dummy_data) for name in unbound_refs_in_before_agg_result],
            before_aggregate.result[0])
        blk_binding_refs_in_before_fed_agg = building_blocks.Block(
            [(name, dummy_data)
             for name in unbound_refs_in_before_fed_agg_result],
            before_federated_aggregate.result)

        self.assertTrue(
            tree_analysis.trees_equal(blk_binding_refs_in_before_agg,
                                      blk_binding_refs_in_before_fed_agg))

        # pyformat: disable
        self.assertEqual(
            before_aggregate.result[1].formatted_representation(), '<\n'
            '  federated_value_at_clients(<>),\n'
            '  <>\n'
            '>')
        # pyformat: enable

        self.assertIsInstance(after_aggregate, building_blocks.Lambda)
        self.assertIsInstance(after_aggregate.result, building_blocks.Call)

        self.assertTrue(
            tree_analysis.trees_equal(after_aggregate.result.function,
                                      after_federated_aggregate))

        # pyformat: disable
        self.assertEqual(
            after_aggregate.result.argument.formatted_representation(), '<\n'
            '  _var1[0],\n'
            '  _var1[1][0]\n'
            '>')