Example #1
0
def _can_extract_intrinsics_to_top_level_lambda(comp, uri):
    """Reports whether the called intrinsics for `uri` can be hoisted.

  Hoisting is allowed exactly when none of the matching called intrinsics close
  over an intermediate variable: the only unbound name any of them may use is
  the parameter declared by the top-level lambda `comp` itself.

  Args:
    comp: The `building_blocks.Lambda` to inspect. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A Python `list` of intrinsic URIs (each a `str`).

  Returns:
    `True` if every matching called intrinsic can be extracted, else `False`.
  """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for single_uri in uri:
        py_typecheck.check_type(single_uri, str)
    tree_analysis.check_has_unique_names(comp)

    # Stop at the first intrinsic that references anything besides the
    # top-level parameter; this mirrors the short-circuiting of `all(...)`.
    for called_intrinsic in _get_called_intrinsics(comp, uri):
        if not tree_analysis.contains_no_unbound_references(
                called_intrinsic, comp.parameter_name):
            return False
    return True
Example #2
0
def _as_function_of_some_federated_subparameters(
    bb: building_blocks.Lambda,
    paths,
) -> building_blocks.Lambda:
    """Turns `x -> ...only uses parts of x...` into `parts_of_x -> ...`.

  Each entry of `paths` is a sequence of `int` indices and/or `str` field names
  that selects a federated value out of `bb.parameter_type`. All selected
  values must share one placement. The returned lambda takes a single federated
  parameter whose member is a struct of the selected member types, and uses of
  each path in `bb` are rewritten (via `_replace_selections`) to read from it.

  Args:
    bb: The `building_blocks.Lambda` to rebind. Names in `bb` must be unique.
    paths: A non-empty iterable of selection paths into `bb.parameter_type`.

  Returns:
    A new `building_blocks.Lambda` whose parameter is the zipped selection.

  Raises:
    _ParameterSelectionError: If a path steps into a non-struct type, uses an
      out-of-range `int` index, or names a field that does not exist.
    _NonFederatedSelectionError: If a path ends at a non-federated type.
    _MismatchedSelectionPlacementError: If the selected values are not all at
      the same placement.
  """
    tree_analysis.check_has_unique_names(bb)
    bb = _prepare_for_rebinding(bb)
    name_generator = building_block_factory.unique_name_generator(bb)

    type_list = []
    int_paths = []
    for path in paths:
        selected_type = bb.parameter_type
        int_path = []
        for index in path:
            if not selected_type.is_struct():
                raise _ParameterSelectionError(path, bb)
            if isinstance(index, int):
                # Valid positions are 0..len-1. Negative indices are rejected
                # too: they would be recorded verbatim in `int_path` and never
                # match the non-negative index of a selection in the body.
                if index < 0 or index >= len(selected_type):
                    raise _ParameterSelectionError(path, bb)
                int_path.append(index)
            else:
                py_typecheck.check_type(index, str)
                if not structure.has_field(selected_type, index):
                    raise _ParameterSelectionError(path, bb)
                int_path.append(
                    structure.name_to_index_map(selected_type)[index])
            selected_type = selected_type[index]
        if not selected_type.is_federated():
            raise _NonFederatedSelectionError(
                'Attempted to rebind references to parameter selection path '
                f'{path} from type {bb.parameter_type}, but the value at that path '
                f'was of non-federated type {selected_type}. Selections must all '
                f'be of federated type. Original AST:\n{bb}')
        int_paths.append(tuple(int_path))
        type_list.append(selected_type)

    # NOTE(review): an empty `paths` fails here with IndexError.
    placement = type_list[0].placement
    if not all(x.placement is placement for x in type_list):
        raise _MismatchedSelectionPlacementError(
            'In order to zip the argument to the lower-level lambda together, all '
            'selected arguments should be at the same placement. Your selections '
            f'have resulted in the list of types:\n{type_list}')

    zip_type = computation_types.FederatedType([x.member for x in type_list],
                                               placement=placement)
    ref_to_zip = building_blocks.Reference(next(name_generator), zip_type)
    path_to_replacement = {}
    for i, path in enumerate(int_paths):
        path_to_replacement[path] = _construct_selection_from_federated_tuple(
            ref_to_zip, i, name_generator)

    new_lambda_body = _replace_selections(bb.result, bb.parameter_name,
                                          path_to_replacement)
    lambda_with_zipped_param = building_blocks.Lambda(
        ref_to_zip.name, ref_to_zip.type_signature, new_lambda_body)
    tree_analysis.check_contains_no_new_unbound_references(
        bb, lambda_with_zipped_param)

    return lambda_with_zipped_param
Example #3
0
def _insert_comp_in_top_level_lambda(comp, name, comp_to_insert):
    """Inserts a computation into `comp` with the given `name`.

  The binding `(name, comp_to_insert)` becomes the first local of the block
  returned by the top-level lambda. If the lambda's result is not already a
  `building_blocks.Block`, it is wrapped in a new block containing only that
  binding.

  Args:
    comp: The `building_blocks.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    name: The name to use.
    comp_to_insert: The `building_blocks.ComputationBuildingBlock` to insert.

  Returns:
    A new `building_blocks.Lambda` with the same parameter as `comp`, whose
    result is a `building_blocks.Block` starting with the inserted binding.
  """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(name, str)
    py_typecheck.check_type(comp_to_insert,
                            building_blocks.ComputationBuildingBlock)
    tree_analysis.check_has_unique_names(comp)

    result = comp.result
    if result.is_block():
        # Build a fresh list so the insertion can never mutate the locals held
        # by the block inside the original `comp`.
        variables = [(name, comp_to_insert)] + list(result.locals)
        result = result.result
    else:
        variables = [(name, comp_to_insert)]
    block = building_blocks.Block(variables, result)
    return building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                  block)
Example #4
0
 def test_raises_lambda_rebinding_of_block_variable(self):
     # The block binds `x`, then the lambda in its result rebinds `x`; this
     # shadowing must be reported as a non-unique name.
     inner_lambda = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     bound_value = building_blocks.Data('x', tf.int32)
     shadowing_block = building_blocks.Block([('x', bound_value)], inner_lambda)
     with self.assertRaises(tree_analysis.NonuniqueNameError):
         tree_analysis.check_has_unique_names(shadowing_block)
Example #5
0
def _split_by_intrinsics_in_top_level_lambda(comp):
    """Splits by the intrinsics in the first block local in the result of `comp`.

  This function splits `comp` into two computations `before` and `after` the
  called intrinsic or tuple of called intrinsics found as the first local in the
  `building_blocks.Block` returned by the top level lambda; and returns a Python
  tuple representing the pair of `before` and `after` computations.

  `before` takes the same parameter as `comp` and returns the argument of the
  called intrinsic (or a struct of the arguments, for a struct of called
  intrinsics). `after` takes a single struct parameter
  `<original parameter, intrinsic result(s)>` and evaluates the original block,
  with the first local rebound to the second element of that struct.

  Args:
    comp: The `building_blocks.Lambda` to split. Its result must be a
      `building_blocks.Block`, and its names must be unique.

  Returns:
    A pair of `building_blocks.ComputationBuildingBlock`s.

  Raises:
    ValueError: If the first local in the `building_blocks.Block` referenced by
      the top level lambda is not a called intrinsic or a
      `building_blocks.Struct` of called intrinsics.
  """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(comp.result, building_blocks.Block)
    tree_analysis.check_has_unique_names(comp)

    name_generator = building_block_factory.unique_name_generator(comp)

    name, first_local = comp.result.locals[0]
    if building_block_analysis.is_called_intrinsic(first_local):
        result = first_local.argument
    elif first_local.is_struct():
        elements = []
        for element in first_local:
            if not building_block_analysis.is_called_intrinsic(element):
                raise ValueError(
                    'Expected all the elements of the `building_blocks.Struct` to be '
                    'called intrinsics, but found: \n{}'.format(element))
            elements.append(element.argument)
        result = building_blocks.Struct(elements)
    else:
        raise ValueError(
            'Expected either a called intrinsic or a `building_blocks.Struct` of '
            'called intrinsics, but found: \n{}'.format(first_local))

    before = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    result)

    ref_name = next(name_generator)
    ref_type = computation_types.StructType(
        (comp.parameter_type, first_local.type_signature))
    ref = building_blocks.Reference(ref_name, ref_type)
    sel_after_arg_1 = building_blocks.Selection(ref, index=0)
    sel_after_arg_2 = building_blocks.Selection(ref, index=1)

    # Build a fresh locals list: rebind the original parameter name first, then
    # replace the intrinsic local with the selected result. Copying ensures the
    # block inside the input `comp` is never mutated.
    variables = [
        (comp.parameter_name, sel_after_arg_1),
        (name, sel_after_arg_2),
    ] + list(comp.result.locals)[1:]
    block = building_blocks.Block(variables, comp.result.result)
    after = building_blocks.Lambda(ref.name, ref.type_signature, block)
    return before, after
Example #6
0
def concatenate_function_outputs(first_function, second_function):
    """Constructs a new function concatenating the outputs of its arguments.

  Assumes that `first_function` and `second_function` already have unique
  names, and have declared parameters of the same type. The constructed
  function will bind its parameter to each of the parameters of
  `first_function` and `second_function`, and return the result of executing
  these functions in parallel and concatenating the outputs in a tuple.

  Args:
    first_function: Instance of `building_blocks.Lambda` whose result we wish to
      concatenate with the result of `second_function`.
    second_function: Instance of `building_blocks.Lambda` whose result we wish
      to concatenate with the result of `first_function`.

  Returns:
    A new instance of `building_blocks.Lambda` with unique names representing
    the computation described above.

  Raises:
    TypeError: If the arguments are not instances of `building_blocks.Lambda`,
    or declare parameters of different types.
  """

    py_typecheck.check_type(first_function, building_blocks.Lambda)
    py_typecheck.check_type(second_function, building_blocks.Lambda)
    tree_analysis.check_has_unique_names(first_function)
    tree_analysis.check_has_unique_names(second_function)

    if first_function.parameter_type != second_function.parameter_type:
        # Report the parameter types being compared, not the full function
        # signatures, so the message matches what it describes.
        raise TypeError(
            'Must pass two functions which declare the same parameter '
            'type to `concatenate_function_outputs`; you have passed '
            'one function which declared a parameter of type {}, and '
            'another which declares a parameter of type {}'.format(
                first_function.parameter_type, second_function.parameter_type))

    def _rename_first_function_arg(comp):
        # Redirect references to the first function's parameter so that both
        # bodies read the single parameter of the combined lambda.
        if comp.is_reference() and comp.name == first_function.parameter_name:
            if comp.type_signature != second_function.parameter_type:
                raise AssertionError('{}, {}'.format(
                    comp.type_signature, second_function.parameter_type))
            return building_blocks.Reference(second_function.parameter_name,
                                             comp.type_signature), True
        return comp, False

    first_function, _ = transformation_utils.transform_postorder(
        first_function, _rename_first_function_arg)

    concatenated_function = building_blocks.Lambda(
        second_function.parameter_name, second_function.parameter_type,
        building_blocks.Struct([first_function.result,
                                second_function.result]))

    # The two bodies may bind overlapping names; re-uniquify the combination.
    renamed, _ = tree_transformations.uniquify_reference_names(
        concatenated_function)

    return renamed
Example #7
0
def _as_function_of_single_subparameter(bb: building_blocks.Lambda,
                                        index: int) -> building_blocks.Lambda:
    """Rewrites `x -> ...only uses x_i...` as `x_i -> ...only uses x_i`.

  Selections of element `index` from the parameter of `bb` are replaced by a
  reference to a freshly named parameter of the element's type.
  """
    tree_analysis.check_has_unique_names(bb)
    prepared = _prepare_for_rebinding(bb)
    fresh_name = next(building_block_factory.unique_name_generator(prepared))
    element_type = prepared.type_signature.parameter[index]
    subparameter_ref = building_blocks.Reference(fresh_name, element_type)
    replacements = {(index,): subparameter_ref}
    rebound_body = _replace_selections(prepared.result,
                                       prepared.parameter_name, replacements)
    rebound_lambda = building_blocks.Lambda(subparameter_ref.name,
                                            subparameter_ref.type_signature,
                                            rebound_body)
    # Guard against the rewrite leaving references to the old parameter.
    tree_analysis.check_contains_no_new_unbound_references(
        prepared, rebound_lambda)
    return rebound_lambda
Example #8
0
    def test_single_level_block(self):
        # Repeated bindings of `a` in one block must receive fresh names.
        data = building_blocks.Data('data', tf.int32)
        a_ref = building_blocks.Reference('a', tf.int32)
        bindings = (('a', data), ('a', a_ref), ('a', a_ref))
        block = building_blocks.Block(bindings, a_ref)

        uniquified, was_modified = (
            tree_transformations.uniquify_reference_names(block))

        expected_before = '(let a=data,a=a,a=a in a)'
        expected_after = '(let a=data,_var1=a,_var2=_var1 in _var2)'
        self.assertEqual(block.compact_representation(), expected_before)
        self.assertEqual(uniquified.compact_representation(), expected_after)
        tree_analysis.check_has_unique_names(uniquified)
        self.assertTrue(was_modified)
Example #9
0
    def test_nested_blocks(self):
        # Bindings of `a` in an inner block must not collide with the outer one.
        data = building_blocks.Data('data', tf.int32)
        a_ref = building_blocks.Reference('a', tf.int32)
        inner_block = building_blocks.Block([('a', data), ('a', a_ref)], a_ref)
        outer_block = building_blocks.Block([('a', data), ('a', a_ref)],
                                            inner_block)

        uniquified, was_modified = (
            tree_transformations.uniquify_reference_names(outer_block))

        expected_before = '(let a=data,a=a in (let a=data,a=a in a))'
        expected_after = (
            '(let a=data,_var1=a in (let _var2=data,_var3=_var2 in _var3))')
        self.assertEqual(outer_block.compact_representation(), expected_before)
        self.assertEqual(uniquified.compact_representation(), expected_after)
        tree_analysis.check_has_unique_names(uniquified)
        self.assertTrue(was_modified)
Example #10
0
    def test_nested_lambdas(self):
        # `(b -> b)((a -> a)(data))` already has unique names: no renaming.
        data = building_blocks.Data('data', tf.int32)
        ref_a = building_blocks.Reference('a', data.type_signature)
        identity_a = building_blocks.Lambda('a', ref_a.type_signature, ref_a)
        inner_call = building_blocks.Call(identity_a, data)
        ref_b = building_blocks.Reference('b', inner_call.type_signature)
        identity_b = building_blocks.Lambda('b', ref_b.type_signature, ref_b)
        outer_call = building_blocks.Call(identity_b, inner_call)

        uniquified, was_modified = (
            tree_transformations.uniquify_reference_names(outer_call))

        self.assertEqual(uniquified.compact_representation(),
                         '(b -> b)((a -> a)(data))')
        tree_analysis.check_has_unique_names(uniquified)
        self.assertFalse(was_modified)
  def test_parameters_are_mapped_together(self):
    x_lambda = building_blocks.Lambda(
        'x', tf.int32, building_blocks.Reference('x', tf.int32))
    y_lambda = building_blocks.Lambda(
        'y', tf.int32, building_blocks.Reference('y', tf.int32))
    concatenated = transformations.concatenate_function_outputs(
        x_lambda, y_lambda)
    expected_name = concatenated.parameter_name

    def _assert_only_parameter_referenced(comp):
      # Every reference in the combined body must point at its one parameter.
      is_foreign_reference = (
          isinstance(comp, building_blocks.Reference) and
          comp.name != expected_name)
      if is_foreign_reference:
        raise ValueError
      return comp, True

    tree_analysis.check_has_unique_names(concatenated)
    transformation_utils.transform_postorder(
        concatenated, _assert_only_parameter_referenced)
Example #12
0
 def test_binding_multiple_args_results_in_unique_names(self):
     clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
     server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
     param_type = computation_types.StructType(
         [[clients_int], server_int, [clients_int]])

     def _select_first_of(outer_index):
         # Builds `x[outer_index][0]` over a fresh reference to `x`.
         param_ref = building_blocks.Reference('x', param_type)
         outer = building_blocks.Selection(param_ref, index=outer_index)
         return building_blocks.Selection(outer, index=0)

     lam = building_blocks.Lambda(
         'x', param_type,
         building_blocks.Struct([_select_first_of(0), _select_first_of(2)]))
     rebound = form_utils._as_function_of_some_federated_subparameters(
         lam, [(0, 0), (2, 0)])
     tree_analysis.check_has_unique_names(rebound)
Example #13
0
    def test_blocks_nested_inside_of_locals(self):
        data = building_blocks.Data('data', tf.int32)
        a_ref = building_blocks.Reference('a', tf.int32)

        def _wrap_in_block(value):
            # `(let a=value in data)`
            return building_blocks.Block([('a', value)], data)

        data_chain = _wrap_in_block(_wrap_in_block(_wrap_in_block(data)))
        ref_chain = _wrap_in_block(_wrap_in_block(_wrap_in_block(a_ref)))
        outermost = building_blocks.Block([('a', data_chain), ('a', ref_chain)],
                                          ref_chain)

        transformed_comp = self.assert_transforms(
            outermost, 'uniquify_names_blocks_nested_inside_of_locals.expected')
        tree_analysis.check_has_unique_names(transformed_comp)
Example #14
0
 def test_binding_multiple_args_results_in_unique_names(self):
     clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
     server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
     param_type = computation_types.NamedTupleType(
         [[clients_int], server_int, [clients_int]])

     def _select_first_of(outer_index):
         # Builds `x[outer_index][0]` over a fresh reference to `x`.
         param_ref = building_blocks.Reference('x', param_type)
         outer = building_blocks.Selection(param_ref, index=outer_index)
         return building_blocks.Selection(outer, index=0)

     lam = building_blocks.Lambda(
         'x', param_type,
         building_blocks.Tuple([_select_first_of(0), _select_first_of(2)]))
     zipped = (mapreduce_transformations
               .zip_selection_as_argument_to_lower_level_lambda(
                   lam, [[0, 0], [2, 0]]))
     tree_analysis.check_has_unique_names(zipped)
Example #15
0
    def test_block_lambda_block_lambda(self):
        a_ref = building_blocks.Reference('a', tf.int32)
        identity = building_blocks.Lambda('a', tf.int32, a_ref)
        inner_block = building_blocks.Block([('a', a_ref), ('a', a_ref)],
                                            building_blocks.Call(identity, a_ref))
        outer_call = building_blocks.Call(
            building_blocks.Lambda('a', tf.int32, inner_block), a_ref)
        data = building_blocks.Data('data', tf.int32)
        outer_block = building_blocks.Block([('a', data), ('a', a_ref)],
                                            outer_call)

        uniquified, was_modified = (
            tree_transformations.uniquify_reference_names(outer_block))

        expected_before = (
            '(let a=data,a=a in (a -> (let a=a,a=a in (a -> a)(a)))(a))')
        expected_after = (
            '(let a=data,_var1=a in (_var2 -> (let _var3=_var2,_var4=_var3 in (_var5 -> _var5)(_var4)))(_var1))'
        )
        self.assertEqual(outer_block.compact_representation(), expected_before)
        self.assertEqual(uniquified.compact_representation(), expected_after)
        tree_analysis.check_has_unique_names(uniquified)
        self.assertTrue(was_modified)
Example #16
0
 def test_ok_on_nested_lambdas_with_different_variable_name(self):
     # `y -> (x -> x)`: distinct parameter names, so no error is expected.
     inner = building_blocks.Lambda('x', tf.int32,
                                    building_blocks.Reference('x', tf.int32))
     outer = building_blocks.Lambda('y', tf.int32, inner)
     tree_analysis.check_has_unique_names(outer)
Example #17
0
 def test_raises_on_nested_lambdas_with_same_variable_name(self):
     # `x -> (x -> x)`: the inner parameter shadows the outer one.
     inner = building_blocks.Lambda('x', tf.int32,
                                    building_blocks.Reference('x', tf.int32))
     outer = building_blocks.Lambda('x', tf.int32, inner)
     with self.assertRaises(tree_analysis.NonuniqueNameError):
         tree_analysis.check_has_unique_names(outer)
Example #18
0
 def test_ok_on_multiple_no_arg_lambdas(self):
     # Parameterless lambdas bind no name, so several of them may coexist.
     payload = building_blocks.Data('x', tf.int32)
     no_arg_lambdas = [
         building_blocks.Lambda(None, None, payload) for _ in range(2)
     ]
     tree_analysis.check_has_unique_names(building_blocks.Struct(no_arg_lambdas))
Example #19
0
 def test_ok_on_single_lambda(self):
     # A lone identity lambda `x -> x` trivially has unique names.
     identity = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     tree_analysis.check_has_unique_names(identity)
Example #20
0
 def test_ok_on_sequential_binding_of_different_variable_in_block(self):
     # `(let x=..., y=... in ...)` binds two distinct names.
     value = building_blocks.Data('x', tf.int32)
     bindings = [('x', value), ('y', value)]
     tree_analysis.check_has_unique_names(building_blocks.Block(bindings, value))
Example #21
0
def _group_by_intrinsics_in_top_level_lambda(comp):
    """Groups the intrinsics in the first block local in the result of `comp`.

  This transformation creates an AST by replacing the tuple of called intrinsics
  found as the first local in the `building_blocks.Block` returned by the top
  level lambda with two new computations. The first computation is a tuple of
  tuples of called intrinsics, representing the original tuple of called
  intrinsics grouped by URI. The second computation is a tuple of selections
  from the first computation, reconstructing the original tuple of called
  intrinsics.

  A group that contains only one intrinsic is not wrapped in a nested tuple;
  the called intrinsic itself is used as that group's element.

  It is necessary to group intrinsics before it is possible to merge them.

  Args:
    comp: The `building_blocks.Lambda` to transform. Its result must be a
      `building_blocks.Block`, and its names must be unique.

  Returns:
    A pair `(fn, modified)`. If every called intrinsic in the first local has a
    distinct URI there is nothing to group, and `(comp, False)` is returned.
    Otherwise `fn` is a new `building_blocks.Lambda` returning a
    `building_blocks.Block` whose first local is the tuple of intrinsics
    grouped by URI, and `modified` is `True`.

  Raises:
    ValueError: If the first local in the `building_blocks.Block` referenced by
      the top level lambda is not a `building_blocks.Struct` of called
      intrinsics.
  """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(comp.result, building_blocks.Block)
    tree_analysis.check_has_unique_names(comp)

    name_generator = building_block_factory.unique_name_generator(comp)

    name, first_local = comp.result.locals[0]
    py_typecheck.check_type(first_local, building_blocks.Struct)
    for element in first_local:
        if not building_block_analysis.is_called_intrinsic(element):
            raise ValueError(
                'Expected all the elements of the `building_blocks.Struct` to be '
                'called intrinsics, but found: \n{}'.format(element))

    # Create collections of data describing how to pack and unpack the intrinsics
    # into groups by their URI.
    #
    # packed_keys is a list of unique URIs ordered by first occurrence in the
    #   original tuple of called intrinsics; group_index_by_uri maps each URI to
    #   its position in packed_keys (a dict lookup instead of `list.index`).
    # packed_groups is a `collections.OrderedDict` where each key is a URI to
    #   group by and each value is a list of intrinsics with that URI.
    # packed_indexes[i] is a `(group_index, index_within_group)` pair locating
    #   the i-th original intrinsic inside the packed structure.
    packed_keys = []
    group_index_by_uri = {}
    for called_intrinsic in first_local:
        uri = called_intrinsic.function.uri
        if uri not in group_index_by_uri:
            group_index_by_uri[uri] = len(packed_keys)
            packed_keys.append(uri)
    # If there are no duplicates, return early.
    if len(packed_keys) == len(first_local):
        return comp, False
    packed_groups = collections.OrderedDict([(x, []) for x in packed_keys])
    packed_indexes = []
    for called_intrinsic in first_local:
        uri = called_intrinsic.function.uri
        packed_group = packed_groups[uri]
        packed_group.append(called_intrinsic)
        packed_indexes.append((group_index_by_uri[uri], len(packed_group) - 1))

    packed_elements = []
    for called_intrinsics in packed_groups.values():
        if len(called_intrinsics) > 1:
            element = building_blocks.Struct(called_intrinsics)
        else:
            element = called_intrinsics[0]
        packed_elements.append(element)
    packed_comp = building_blocks.Struct(packed_elements)

    packed_ref_name = next(name_generator)
    packed_ref_type = computation_types.to_type(packed_comp.type_signature)
    packed_ref = building_blocks.Reference(packed_ref_name, packed_ref_type)

    unpacked_elements = []
    for group_index, intrinsic_index in packed_indexes:
        sel = building_blocks.Selection(packed_ref, index=group_index)
        # Singleton groups were not wrapped in a Struct, so only multi-member
        # groups need the second-level selection.
        if len(packed_groups[packed_keys[group_index]]) > 1:
            sel = building_blocks.Selection(sel, index=intrinsic_index)
        unpacked_elements.append(sel)
    unpacked_comp = building_blocks.Struct(unpacked_elements)

    # Work on a copy so the block held by the input `comp` is never mutated.
    variables = list(comp.result.locals)
    variables[0] = (name, unpacked_comp)
    variables.insert(0, (packed_ref_name, packed_comp))
    block = building_blocks.Block(variables, comp.result.result)
    fn = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                block)
    return fn, True
Example #22
0
def _extract_intrinsics_to_top_level_lambda(comp, uri):
    r"""Extracts intrinsics in `comp` for the given `uri`.

  This transformation creates an AST such that all the called intrinsics for the
  given `uri` in body of the `building_blocks.Block` returned by the top level
  lambda have been extracted to the top level lambda and replaced by selections
  from a reference to the constructed variable.

                       Lambda
                       |
                       Block
                      /     \
        [x=Struct, ...]       Comp
           |
           [Call,                  Call                   Call]
           /    \                 /    \                 /    \
  Intrinsic      Comp    Intrinsic      Comp    Intrinsic      Comp

  The order of the extracted called intrinsics matches the order of `uri`.
  If exactly one called intrinsic matches, it is bound directly rather than
  wrapped in a struct.

  Note: if this function is passed an AST which contains nested called
  intrinsics, it will fail, as it will mutate the subcomputation containing
  the lower-level called intrinsics on the way back up the tree.

  Args:
    comp: The `building_blocks.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A Python `list` of intrinsic URIs (each a `str`).

  Returns:
    A pair `(new_comp, modified)`. If `comp` contains no called intrinsics for
    `uri`, this is `(comp, False)`. Otherwise `new_comp` is the transformed
    `building_blocks.Lambda` and `modified` is `True`.

  Raises:
    ValueError: If all the intrinsics for the given `uri` in `comp` are not
      exclusively bound by `comp`.
  """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for x in uri:
        py_typecheck.check_type(x, str)
    tree_analysis.check_has_unique_names(comp)

    name_generator = building_block_factory.unique_name_generator(comp)

    intrinsics = _get_called_intrinsics(comp, uri)
    for intrinsic in intrinsics:
        if not tree_analysis.contains_no_unbound_references(
                intrinsic, comp.parameter_name):
            raise ValueError(
                'Expected a computation which binds all the references in all the '
                'intrinsic with the uri: {}.'.format(uri))
    if not intrinsics:
        # Nothing matches `uri`, so there is nothing to extract.
        return comp, False
    if len(intrinsics) > 1:
        # Rank each URI by its first occurrence in `uri`; `sorted` is stable,
        # so intrinsics sharing a URI keep their discovery order.
        order = {}
        for index, element in enumerate(uri):
            order.setdefault(element, index)
        intrinsics = sorted(intrinsics, key=lambda x: order[x.function.uri])
        extracted_comp = building_blocks.Struct(intrinsics)
    else:
        extracted_comp = intrinsics[0]
    ref_name = next(name_generator)
    ref_type = computation_types.to_type(extracted_comp.type_signature)
    ref = building_blocks.Reference(ref_name, ref_type)

    def _transform(inner_comp):
        # Replace each matching called intrinsic with a reference to (or a
        # selection from) the newly bound extracted value.
        if not building_block_analysis.is_called_intrinsic(inner_comp, uri):
            return inner_comp, False
        if len(intrinsics) > 1:
            position = intrinsics.index(inner_comp)
            return building_blocks.Selection(ref, index=position), True
        return ref, True

    comp, _ = transformation_utils.transform_postorder(comp, _transform)
    comp = _insert_comp_in_top_level_lambda(comp,
                                            name=ref.name,
                                            comp_to_insert=extracted_comp)
    return comp, True
Example #23
0
 def test_ok_on_single_block(self):
     # `(let x=... in ...)` binds a single name, which is trivially unique.
     value = building_blocks.Data('x', tf.int32)
     tree_analysis.check_has_unique_names(
         building_blocks.Block([('x', value)], value))
Example #24
0
 def test_raises_on_sequential_binding_of_same_variable_in_block(self):
     # Binding `x` twice in the same block must be rejected.
     value = building_blocks.Data('x', tf.int32)
     duplicate_bindings = building_blocks.Block([('x', value), ('x', value)],
                                                value)
     with self.assertRaises(tree_analysis.NonuniqueNameError):
         tree_analysis.check_has_unique_names(duplicate_bindings)
Example #25
0
def remove_duplicate_called_graphs(comp):
  """Deduplicates called TensorFlow graphs and regenerates TensorFlow for `comp`.

  Every call to a `building_blocks.CompiledComputation` found in `comp` (or in
  `comp.result`, when `comp` is a lambda) is bound once in a block. Calls that
  are equal according to `tree_analysis.trees_equal` share a single binding and
  are replaced by a `building_blocks.Reference` to it. The resulting block is
  then turned back into TensorFlow.

  Eligibility: `comp` must either be a lambda whose body contains no lambdas or
  blocks, or a non-lambda computation containing no lambdas or blocks. This is
  required because hoisting called graphs into a block makes no effort to keep
  references inside their defining scope (except for the parameter of a
  top-level lambda). Ineligible inputs are returned unchanged, and a warning is
  logged. `comp` must also contain only computations that can be represented
  in TensorFlow, i.e. satisfy the type restriction in
  `type_analysis.is_tensorflow_compatible_type`.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` whose called
      graphs should be deduplicated. Its names must be unique.

  Returns:
    A 2-tuple following the `transformation_utils.TransformSpec` pattern. For
    ineligible input this is `(comp, False)`. Otherwise the first element is
    either a called `building_blocks.CompiledComputation` or a
    `building_blocks.CompiledComputation` itself, depending on whether `comp`
    is of non-functional or functional type respectively, and the second
    element is `True`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)

  is_lambda = comp.is_lambda()
  body_to_inspect = comp.result if is_lambda else comp
  disallowed_types = (building_blocks.Block, building_blocks.Lambda)
  if tree_analysis.contains_types(body_to_inspect, disallowed_types):
    # Deduplication could pull references out of their scope here; bail out.
    logging.warning(
        'The preprocessors have failed to remove called lambdas '
        'and blocks; falling back to less efficient, but '
        'guaranteed, TensorFlow generation with computation %s.', comp)
    return comp, False

  # Ordered `(name, called_graph)` pairs; becomes the locals of the block.
  deduplicated_bindings = []

  def _replace_with_reference(subcomp):
    """Swaps a called graph for a reference to its (possibly shared) binding."""
    if not (subcomp.is_call() and subcomp.function.is_compiled_computation()):
      return subcomp, False
    bound_name = next(
        (existing_name
         for existing_name, existing_graph in deduplicated_bindings
         if tree_analysis.trees_equal(existing_graph, subcomp)), None)
    if bound_name is None:
      bound_name = next(name_generator)
      deduplicated_bindings.append((bound_name, subcomp))
    return building_blocks.Reference(bound_name, subcomp.type_signature), True

  if not is_lambda:
    rewritten, _ = transformation_utils.transform_postorder(
        comp, _replace_with_reference)
    block = building_blocks.Block(deduplicated_bindings, rewritten)
    tf_generated, _ = create_tensorflow_representing_block(block)
    return tf_generated, True

  rewritten_body, _ = transformation_utils.transform_postorder(
      comp.result, _replace_with_reference)
  block = building_blocks.Block(deduplicated_bindings, rewritten_body)
  parsed_block, _ = create_tensorflow_representing_block(block)
  wrapped = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                   parsed_block)
  tf_parser = tree_to_cc_transformations.TFParser()
  with_identities, _ = tree_transformations.insert_called_tf_identity_at_leaves(
      wrapped)
  tf_generated, _ = transformation_utils.transform_postorder(
      with_identities, tf_parser)
  return tf_generated, True
Example #26
0
 def test_ok_block_binding_of_new_variable(self):
     """A lambda parameter `y` distinct from the block local `x` is accepted."""
     x_data = building_blocks.Data('x', tf.int32)
     inner_block = building_blocks.Block([('x', x_data)], x_data)
     tree_analysis.check_has_unique_names(
         building_blocks.Lambda('y', tf.int32, inner_block))
Example #27
0
 def test_raises_on_none(self):
     """Passing `None` instead of a building block raises `TypeError`."""
     self.assertRaises(TypeError, tree_analysis.check_has_unique_names, None)
Example #28
0
 def test_ok_lambda_binding_of_new_variable(self):
     """A block whose result is a lambda over a fresh name `y` is accepted."""
     identity_fn = building_blocks.Lambda(
         'y', tf.int32, building_blocks.Reference('y', tf.int32))
     bindings = [('x', building_blocks.Data('x', tf.int32))]
     tree_analysis.check_has_unique_names(
         building_blocks.Block(bindings, identity_fn))
Example #29
0
def force_align_and_split_by_intrinsics(
    comp: building_blocks.Lambda,
    intrinsic_defaults: List[building_blocks.Call],
) -> Tuple[building_blocks.Lambda, building_blocks.Lambda]:
    """Divides `comp` into before-and-after of calls to one or more intrinsics.

    The input computation `comp` must have the following properties:

    1. The computation `comp` is completely self-contained, i.e., there are no
       references to arguments introduced in a scope external to `comp`.

    2. `comp`'s return value must not contain uncalled lambdas.

    3. None of the calls to intrinsics in `intrinsic_defaults` may be
       within a lambda passed to another external function (intrinsic or
       compiled computation).

    4. No argument passed to an intrinsic in `intrinsic_defaults` may be
       dependent on the result of a call to an intrinsic in
       `intrinsic_defaults`.

    5. All intrinsics in `intrinsic_defaults` must have "merge-able" arguments.
       Structs will be merged element-wise, federated values will be zipped,
       and functions will be composed:
         `f = lambda f1_arg, f2_arg: (f1(f1_arg), f2(f2_arg))`

    6. All intrinsics in `intrinsic_defaults` must return a single federated
       value whose member is the merged result of any merged calls, i.e.:
         `f(merged_arg).member = (f1(f1_arg).member, f2(f2_arg).member)`

    Under these conditions (and assuming `comp` is a computation with a
    non-`None` argument), this function will return two
    `building_blocks.Lambda`s `before` and `after` such that `comp` is
    semantically equivalent to the following expression*:

    ```
    (arg -> (let
      x=before(arg),
      y=intrinsic1(x[0]),
      z=intrinsic2(x[1]),
      ...
     in after(<arg, <y,z,...>>)))
    ```

    If `comp` is a no-arg computation, the returned computations will be
    equivalent (in the same sense as above) to:
    ```
    ( -> (let
      x=before(),
      y=intrinsic1(x[0]),
      z=intrinsic2(x[1]),
      ...
     in after(<y,z,...>)))
    ```

    *Note that these expressions may not be entirely equivalent under
    nondeterminism, since there is no way in this case to handle computations
    in which `before` creates a random variable that is then used in `after`:
    the only way for state to pass from `before` to `after` is for it to
    travel through one of the intrinsics.

    In this expression, there is only a single call to each intrinsic, which
    results from consolidating all occurrences of that intrinsic in the
    original `comp`. All logic in `comp` that produced inputs to any of these
    intrinsic calls is now consolidated and jointly encapsulated in `before`,
    which produces a combined argument to all the original calls. All the
    remaining logic in `comp`, including that which consumed the outputs of
    the intrinsic calls, is encapsulated in `after`.

    If the original computation `comp` had type `(T -> U)`, then `before` and
    `after` would be `(T -> X)` and `(<T,Y> -> U)`, respectively, where `X` is
    the type of the argument to the single combined intrinsic call above. Note
    that `after` takes the output of the call to the intrinsic as well as the
    original argument to `comp`, as it may be dependent on both.

    Args:
      comp: The instance of `building_blocks.Lambda` that serves as the input
        to this transformation, as described above.
      intrinsic_defaults: A list of intrinsics with which to split the
        computation, provided as a list of `Call`s to insert if no intrinsic
        with a matching URI is found. Intrinsics in this list will be merged,
        and `comp` will be split across them.

    Returns:
      A pair of the form `(before, after)`, where each of `before` and `after`
      is a `building_blocks.Lambda` instance that represents a part of the
      result as specified above.

    Raises:
      ValueError: If the resulting `before` or `after` computation does not
        have unique names.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(intrinsic_defaults, list)
    # Captured before normalization so dependency errors can show the input.
    comp_repr = comp.compact_representation()

    # Flatten `comp` to call-dominant form so that we're working with just a
    # linear list of intrinsic calls with no indirection via tupling,
    # selection, blocks, called lambdas, or references.
    comp = to_call_dominant(comp)

    # CDF can potentially return blocks if there are variables not dependent on
    # the top-level parameter. We normalize these away.
    if not comp.is_lambda():
        comp.check_block()
        comp.result.check_lambda()
        if comp.result.result.is_block():
            additional_locals = comp.result.result.locals
            result = comp.result.result.result
        else:
            additional_locals = []
            result = comp.result.result
        # Note: without uniqueness, a local in `comp.locals` could potentially
        # shadow `comp.result.parameter_name`. However, `to_call_dominant`
        # above ensures that names are unique, as it ends in a call to
        # `uniquify_reference_names`.
        comp = building_blocks.Lambda(
            comp.result.parameter_name, comp.result.parameter_type,
            building_blocks.Block(comp.locals + additional_locals, result))
    comp.check_lambda()

    # Simple computations with no intrinsic calls won't have a block.
    # Normalize these as well.
    if not comp.result.is_block():
        comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                      building_blocks.Block([], comp.result))
    comp.result.check_block()

    name_generator = building_block_factory.unique_name_generator(comp)

    intrinsic_uris = set(call.function.uri for call in intrinsic_defaults)
    deps = _compute_intrinsic_dependencies(intrinsic_uris, comp.parameter_name,
                                           comp.result.locals, comp_repr)
    merged_intrinsics = _compute_merged_intrinsics(intrinsic_defaults,
                                                   deps.uri_to_locals,
                                                   name_generator)

    # Note: the outputs are labeled as `{uri}_param` for convenience, e.g.
    # `federated_secure_sum_param: ...`.
    before = building_blocks.Lambda(
        comp.parameter_name, comp.parameter_type,
        building_blocks.Block(
            deps.locals_not_dependent_on_intrinsics,
            building_blocks.Struct([(f'{merged.uri}_param', merged.args)
                                    for merged in merged_intrinsics])))

    after_param_name = next(name_generator)
    # Shared by both the with-argument and no-argument forms of `after`.
    intrinsic_results_type = computation_types.StructType([
        (f'{merged.uri}_result', merged.return_type)
        for merged in merged_intrinsics
    ])
    has_original_arg = comp.parameter_type is not None
    if has_original_arg:
        # TODO(b/147499373): If None-arguments were uniformly represented as
        # empty tuples, we would be able to avoid this (and related) ugly
        # casing.
        after_param_type = computation_types.StructType([
            ('original_arg', comp.parameter_type),
            ('intrinsic_results', intrinsic_results_type),
        ])
    else:
        after_param_type = computation_types.StructType([
            ('intrinsic_results', intrinsic_results_type),
        ])
    after_param_ref = building_blocks.Reference(after_param_name,
                                                after_param_type)
    if has_original_arg:
        # Rebind the original parameter name so `comp`'s body can be reused.
        original_arg_bindings = [
            (comp.parameter_name,
             building_blocks.Selection(after_param_ref, name='original_arg'))
        ]
    else:
        original_arg_bindings = []

    # For merged intrinsics whose results must be split back into the original
    # locals, bind each original local name to a federated map/apply selecting
    # the corresponding element of the merged result's member.
    unzip_bindings = []
    for merged in merged_intrinsics:
        if merged.unpack_to_locals:
            intrinsic_result = building_blocks.Selection(
                building_blocks.Selection(after_param_ref,
                                          name='intrinsic_results'),
                name=f'{merged.uri}_result')
            select_param_type = intrinsic_result.type_signature.member
            for i, binding_name in enumerate(merged.unpack_to_locals):
                select_param_name = next(name_generator)
                select_param_ref = building_blocks.Reference(
                    select_param_name, select_param_type)
                selected = building_block_factory.create_federated_map_or_apply(
                    building_blocks.Lambda(
                        select_param_name, select_param_type,
                        building_blocks.Selection(select_param_ref, index=i)),
                    intrinsic_result)
                unzip_bindings.append((binding_name, selected))
    after = building_blocks.Lambda(
        after_param_name,
        after_param_type,
        building_blocks.Block(
            original_arg_bindings +
            # Note that we must duplicate `locals_not_dependent_on_intrinsics`
            # across both the `before` and `after` computations since both can
            # rely on them, and there's no way to plumb results from `before`
            # through to `after` except via one of the intrinsics being split
            # upon. In MapReduceForm, this limitation is caused by the fact
            # that `prepare` has no output which serves as an input to
            # `report`.
            deps.locals_not_dependent_on_intrinsics + unzip_bindings +
            deps.locals_dependent_on_intrinsics,
            comp.result.result))
    try:
        tree_analysis.check_has_unique_names(before)
        tree_analysis.check_has_unique_names(after)
    except tree_analysis.NonuniqueNameError as e:
        raise ValueError(
            f'nonunique names in result of splitting\n{comp}') from e
    return before, after