def _insert_comp_in_top_level_lambda(comp, name, comp_to_insert):
  """Inserts a computation into `comp` with the given `name`.

  `comp_to_insert` is bound to `name` as the first local of a
  `tff_framework.Block`, which becomes the result of the returned lambda. If
  the result of `comp` is already a `tff_framework.Block`, its existing locals
  are kept (after the new binding) and its result is reused; otherwise a new
  block wrapping the original result is created.

  Args:
    comp: The `tff_framework.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    name: The name to use.
    comp_to_insert: The `tff_framework.ComputationBuildingBlock` to insert.

  Returns:
    A new `tff_framework.Lambda` with the same parameter name and type as
    `comp`, whose result is a block binding `name` to `comp_to_insert`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(name, six.string_types)
  py_typecheck.check_type(comp_to_insert,
                          tff_framework.ComputationBuildingBlock)
  result = comp.result
  if isinstance(result, tff_framework.Block):
    # `result.locals` may expose the block's own list; copy it so that the
    # insertion below can never mutate the input computation in place.
    variables = list(result.locals)
    result = result.result
  else:
    variables = []
  # Insert at the front so later locals (and the result) can refer to `name`.
  variables.insert(0, (name, comp_to_insert))
  block = tff_framework.Block(variables, result)
  return tff_framework.Lambda(comp.parameter_name, comp.parameter_type, block)
def concatenate_function_outputs(first_function, second_function):
  """Constructs a new function concatenating the outputs of its arguments.

  Assumes that `first_function` and `second_function` already have unique
  names, and have declared parameters of the same type. The constructed
  function will bind its parameter to each of the parameters of
  `first_function` and `second_function`, and return the result of executing
  these functions in parallel and concatenating the outputs in a tuple.

  Args:
    first_function: Instance of `tff_framework.Lambda` whose result we wish to
      concatenate with the result of `second_function`.
    second_function: Instance of `tff_framework.Lambda` whose result we wish to
      concatenate with the result of `first_function`.

  Returns:
    A new instance of `tff_framework.Lambda` with unique names representing
    the computation described above.

  Raises:
    TypeError: If the arguments are not instances of `tff_framework.Lambda`,
      or declare parameters of different types.
  """
  py_typecheck.check_type(first_function, tff_framework.Lambda)
  py_typecheck.check_type(second_function, tff_framework.Lambda)
  tff_framework.check_has_unique_names(first_function)
  tff_framework.check_has_unique_names(second_function)

  if first_function.parameter_type != second_function.parameter_type:
    # Report the parameter types (what was compared and what the message
    # describes), not the full function type signatures.
    raise TypeError('Must pass two functions which declare the same parameter '
                    'type to `concatenate_function_outputs`; you have passed '
                    'one function which declared a parameter of type {}, and '
                    'another which declares a parameter of type {}'.format(
                        first_function.parameter_type,
                        second_function.parameter_type))

  def _rename_first_function_arg(comp):
    """Points references to the first function's parameter at the second's."""
    if (isinstance(comp, tff_framework.Reference) and
        comp.name == first_function.parameter_name):
      if comp.type_signature != second_function.parameter_type:
        raise AssertionError('{}, {}'.format(comp.type_signature,
                                             second_function.parameter_type))
      return tff_framework.Reference(second_function.parameter_name,
                                     comp.type_signature), True
    return comp, False

  # After this rewrite, both results refer to the second function's parameter,
  # so they can share the single parameter of the concatenated lambda.
  first_function, _ = tff_framework.transform_postorder(
      first_function, _rename_first_function_arg)
  concatenated_function = tff_framework.Lambda(
      second_function.parameter_name, second_function.parameter_type,
      tff_framework.Tuple([first_function.result, second_function.result]))
  # The two bodies may bind overlapping names; re-uniquify the combination.
  renamed, _ = tff_framework.uniquify_reference_names(concatenated_function)
  return renamed
def bind_single_selection_as_argument_to_lower_level_lambda(comp, index):
  r"""Rebinds one element of `comp`'s parameter as an inner lambda's parameter.

  Given a lambda `comp`, returns an equivalent computation with the shape:

                Lambda(x)
                    |
                   Call
                  /    \
            Lambda      Selection from x

  Inside the inner lambda, every selection `x[index]` of the outer parameter
  is replaced by a reference to the inner lambda's own (freshly named)
  parameter. The outer lambda then calls the inner lambda with `x[index]`.

  Args:
    comp: Instance of `tff_framework.Lambda`, whose parameters we wish to
      rebind to a different lambda. This lambda must have unique names.
    index: `int` representing the index to bind as an argument to the
      lower-level lambda.

  Returns:
    An instance of `tff_framework.Lambda`, equivalent to `comp`, matching the
    shape drawn above.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(index, int)
  tff_framework.check_has_unique_names(comp)
  comp = _prepare_for_rebinding(comp)

  names = tff_framework.unique_name_generator(comp)
  outer_param_name = comp.parameter_name
  fresh_name = six.next(names)
  fresh_ref = tff_framework.Reference(fresh_name,
                                      comp.type_signature.parameter[index])

  def _swap_selection_for_fresh_ref(node):
    """Replaces `<outer parameter>[index]` with `fresh_ref`."""
    if not isinstance(node, tff_framework.Selection):
      return node, False
    source = node.source
    if not isinstance(source, tff_framework.Reference):
      return node, False
    if node.index != index or source.name != outer_param_name:
      return node, False
    return fresh_ref, True

  rebound_result, _ = tff_framework.transform_postorder(
      comp.result, _swap_selection_for_fresh_ref)
  inner_lambda = tff_framework.Lambda(fresh_ref.name, fresh_ref.type_signature,
                                      rebound_result)
  _check_for_missed_binding(comp, inner_lambda)

  outer_ref = tff_framework.Reference(comp.parameter_name, comp.parameter_type)
  call = tff_framework.Call(inner_lambda,
                            tff_framework.Selection(outer_ref, index=index))
  return tff_framework.Lambda(comp.parameter_name, comp.parameter_type, call)
def _can_extract_intrinsic_to_top_level_lambda(comp, uri):
  """Reports whether the intrinsic(s) called with `uri` can be extracted.

  Extraction is allowed when every called intrinsic with `uri` found in `comp`
  passes `_are_comps_bound_exclusively_by_top_level_lambda`.

  Args:
    comp: The `tff_framework.Lambda` to test. The names of lambda parameters
      and block variables in `comp` must be unique.
    uri: A URI of an intrinsic.

  Returns:
    `True` if the intrinsic can be extracted, otherwise `False`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(uri, six.string_types)
  return _are_comps_bound_exclusively_by_top_level_lambda(
      comp, _get_called_intrinsics(comp, uri))
def _extract_intrinsic_as_reference_to_top_level_lambda(comp, uri):
  """Extracts an intrinsic from `comp` as a reference for the given `uri`.

  The single called intrinsic with the given `uri` is replaced inside `comp`
  by a reference to a freshly generated name. The intrinsic itself is then
  bound to that name as the first local of a block at the top level of the
  returned lambda.

  Args:
    comp: The `tff_framework.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A URI of an intrinsic.

  Returns:
    A tuple `(new_comp, modified)`, where `new_comp` is the transformed
    `tff_framework.Lambda` and `modified` is always `True` (the same
    `(comp, modified)` pair shape returned by the transforms used here).

  Raises:
    ValueError: If there is not exactly one called intrinsic for the given
      `uri`, or if the intrinsic is not exclusively bound by `comp`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(uri, six.string_types)
  intrinsics = _get_called_intrinsics(comp, uri)
  length = len(intrinsics)
  if length != 1:
    raise ValueError(
        'Expected a computation with exactly one intrinsic with the uri: {}, '
        'found: {}.'.format(uri, length))
  if not _are_comps_bound_exclusively_by_top_level_lambda(comp, intrinsics):
    raise ValueError(
        'Expected a computation which binds all the references in the '
        'intrinsic with the uri: {}.'.format(uri))
  name_generator = tff_framework.unique_name_generator(comp)
  extracted_intrinsic = intrinsics[0]
  ref_name = six.next(name_generator)
  ref_type = tff.to_type(extracted_intrinsic.type_signature)
  ref = tff_framework.Reference(ref_name, ref_type)

  def _transform(inner_comp):
    """Swaps the called intrinsic for `ref`; leaves everything else alone."""
    if not tff_framework.is_called_intrinsic(inner_comp, uri):
      return inner_comp, False
    return ref, True

  comp, _ = tff_framework.transform_postorder(comp, _transform)
  # Re-bind the (unmodified) intrinsic under `ref.name` at the top level so the
  # reference substituted above resolves to it.
  comp = _insert_comp_in_top_level_lambda(
      comp, name=ref.name, comp_to_insert=extracted_intrinsic)
  return comp, True
def _are_comps_bound_exclusively_by_top_level_lambda(comp, comps):
  """Checks that every computation in `comps` is bound only by `comp`.

  For each candidate, the set of its unbound references (as computed over
  `comp`) must equal exactly `{comp.parameter_name}`. Note that a candidate
  with no unbound references therefore does not pass.

  Args:
    comp: The `tff_framework.Lambda` to test. The names of lambda parameters
      and block variables in `comp` must be unique.
    comps: A Python `list`, `tuple` or `set` of computations to test.

  Returns:
    `True` if the unbound references in each computation in `comps` are
    exactly the parameter of `comp` (vacuously `True` for empty `comps`),
    otherwise `False`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(comps, (list, tuple, set))
  unbound_by_comp = tff_framework.get_map_of_unbound_references(comp)
  expected = {comp.parameter_name}
  for candidate in comps:
    if expected != unbound_by_comp[candidate]:
      return False
  return True