コード例 #1
0
def count_tensorflow_variables_under(comp):
    """Returns how many TF variables live in TensorFlow computations in `comp`.

    This is an instrumentation helper: it lets callers inspect the size and
    makeup of the TensorFlow artifacts that were generated.

    Args:
      comp: The `building_blocks.ComputationBuildingBlock` whose TensorFlow
        variables should be tallied.

    Returns:
      An `int`: the sum, over every `building_blocks.CompiledComputation`
      under `comp` whose proto holds a `tensorflow` computation, of the count
      reported by `building_block_analysis.count_tensorflow_variables_in`.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    variable_total = 0

    def _accumulate(node):
        nonlocal variable_total
        is_tf_computation = (
            isinstance(node, building_blocks.CompiledComputation)
            and node.proto.WhichOneof('computation') == 'tensorflow')
        if is_tf_computation:
            variable_total += (
                building_block_analysis.count_tensorflow_variables_in(node))
        # Pure visitor: hand the node back untouched.
        return node, False

    transformation_utils.transform_postorder(comp, _accumulate)
    return variable_total
コード例 #2
0
def check_intrinsics_whitelisted_for_reduction(comp):
    """Verifies every intrinsic under `comp` is on the reducible whitelist.

    The whitelist holds the intrinsics that can currently be reduced to
    `FEDERATED_AGGREGATE`, `FEDERATED_BROADCAST`, or purely local processing.

    Args:
      comp: The `building_blocks.ComputationBuildingBlock` to scan.

    Raises:
      ValueError: If a `building_blocks.Intrinsic` whose URI is not on the
        whitelist appears anywhere under `comp`.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    allowed_intrinsics = (
        intrinsic_defs.FEDERATED_AGGREGATE,
        intrinsic_defs.FEDERATED_APPLY,
        intrinsic_defs.FEDERATED_BROADCAST,
        intrinsic_defs.FEDERATED_MAP,
        intrinsic_defs.FEDERATED_MAP_ALL_EQUAL,
        intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS,
        intrinsic_defs.FEDERATED_VALUE_AT_SERVER,
        intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
        intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
    )
    allowed_uris = tuple(intrinsic.uri for intrinsic in allowed_intrinsics)

    def _reject_unlisted_intrinsic(node):
        # Anything that is not an intrinsic, or is a whitelisted one, passes.
        if not isinstance(node, building_blocks.Intrinsic):
            return node, False
        if node.uri in allowed_uris:
            return node, False
        raise ValueError(
            'Encountered an Intrinsic not currently reducible to aggregate or '
            'broadcast, the intrinsic {}'.format(
                node.compact_representation()))

    transformation_utils.transform_postorder(comp, _reject_unlisted_intrinsic)
コード例 #3
0
    def test_compile_computation(self):
        # After compilation, no `federated_sum` intrinsic may remain in the
        # resulting building-block tree.
        client_floats = computation_types.FederatedType(
            tf.float32, placements.CLIENTS)
        server_float = computation_types.FederatedType(
            tf.float32, placements.SERVER, True)

        @computations.federated_computation([client_floats, server_float])
        def foo(temperatures, threshold):
            exceeds_threshold = computations.tf_computation(
                lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
                [tf.float32, tf.float32])
            broadcast_threshold = intrinsics.federated_broadcast(threshold)
            indicators = intrinsics.federated_map(
                exceeds_threshold, [temperatures, broadcast_threshold])
            return intrinsics.federated_sum(indicators)

        compiled_foo = compiler_pipeline.CompilerPipeline(
            context_stack_impl.context_stack).compile(foo)
        compiled_proto = computation_impl.ComputationImpl.get_proto(
            compiled_foo)
        compiled_tree = building_blocks.ComputationBuildingBlock.from_proto(
            compiled_proto)

        def _assert_not_federated_sum(node):
            if isinstance(node, building_blocks.Intrinsic):
                self.assertNotEqual(node.uri, intrinsic_defs.FEDERATED_SUM.uri)
            return node, False

        transformation_utils.transform_postorder(compiled_tree,
                                                 _assert_not_federated_sum)
コード例 #4
0
    def _check_no_functional_symbol_bindings(comp):
        """Checks that no symbol under `comp` is bound to a functional value.

        When this condition holds, every function that is semantically called
        (i.e., every function that will eventually be invoked when the
        computation runs) is called directly. Such functions can then be
        extracted simply by pattern-matching on `building_blocks.Call`.

        Args:
          comp: Instance of `building_blocks.ComputationBuildingBlock` to check
            for the absence of functional symbol bindings.

        Raises:
          TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
          ValueError: If any block under `comp` binds a symbol to a
            computation whose type signature contains a functional type.
        """
        py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

        def _check_for_bindings(comp_to_check):
            if comp_to_check.is_block():
                for name, local in comp_to_check.locals:
                    if type_analysis.contains(local.type_signature,
                                              lambda x: x.is_function()):
                        raise ValueError(
                            'We make the assumption when reducing to '
                            'call-dominant form that there are no symbols bound '
                            'to computations with functional type; encountered '
                            'the computation {c} of type {t} bound to symbol '
                            '{s}. Failure here indicates an internal error in '
                            'the construction of call-dominant form.'.format(
                                c=local, t=local.type_signature, s=name))
            # Return the node being visited, not the outer `comp`: handing the
            # root back in place of every subtree is wrong for a post-order
            # transform, even though the result is discarded here.
            return comp_to_check, False

        transformation_utils.transform_postorder(comp, _check_for_bindings)
コード例 #5
0
def _get_unbound_ref(block):
  """Returns a `Reference` for the single unbound name in `block`, if any.

  Args:
    block: The building block to inspect.

  Returns:
    `None` when `block` has no unbound references. Otherwise, a
    `building_blocks.Reference` carrying the unbound name and the type
    signature of a reference to that name in `block` (the last such reference
    visited in post-order).

  Raises:
    ValueError: If `block` has more than one unbound reference.
  """
  unbound_by_comp = transformation_utils.get_map_of_unbound_references(block)
  unbound_names = unbound_by_comp[block]
  if not unbound_names:
    return None
  if len(unbound_names) > 1:
    raise ValueError('`create_tensorflow_representing_block` must be passed '
                     'a block with at most a single unbound reference; '
                     'encountered the block {} with {} unbound '
                     'references.'.format(block, len(unbound_names)))

  (ref_name,) = unbound_names
  ref_type = None

  def _record_ref_type(node):
    nonlocal ref_type
    if node.is_reference() and node.name == ref_name:
      ref_type = node.type_signature
    return node, False

  transformation_utils.transform_postorder(block, _record_ref_type)
  return building_blocks.Reference(ref_name, ref_type)
コード例 #6
0
def check_has_single_placement(comp, single_placement):
    """Ensures `single_placement` is the only placement appearing in `comp`.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock`.
      single_placement: Instance of `placement_literals.PlacementLiteral`
        that every federated type under `comp` must carry.

    Raises:
      ValueError: If any node under `comp` has a
        `computation_types.FederatedType` type signature whose placement
        differs from `single_placement`.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(single_placement,
                            placement_literals.PlacementLiteral)

    def _reject_other_placements(node):
        """Raises if `node` is federated at a placement other than expected."""
        node_type = node.type_signature
        if not isinstance(node_type, computation_types.FederatedType):
            return node, False
        if node_type.placement != single_placement:
            raise ValueError(
                'Comp contains a placement other than {}; '
                'placement {} on comp {} inside the structure. '.format(
                    single_placement, node_type.placement,
                    node.compact_representation()))
        return node, False

    transformation_utils.transform_postorder(comp, _reject_other_placements)
コード例 #7
0
def dedupe_and_merge_tuple_intrinsics(comp, uri):
  """Merges tuples of called intrinsics into one called intrinsic.

  Runs in two post-order passes: first a simplification that rewrites
  selections out of blocks whose result is a struct, then a
  `RemoveDuplicatesAndApplyTransform` pass driven by
  `tree_transformations.MergeTupleIntrinsics`.

  Args:
    comp: The building block to transform.
    uri: URI of the intrinsic to merge; passed to
      `tree_transformations.MergeTupleIntrinsics`.

  Returns:
    A tuple `(transformed_comp, modified)`, where `modified` is true if either
    pass changed the tree.
  """

  # TODO(b/147359721): The application of the function below is a workaround to
  # a known pattern preventing TFF from deduplicating, effectively because tree
  # equality won't determine that [a, a][0] and [a, a][1] are actually the same
  # thing. A fuller fix is planned, but requires increasing the invariants
  # respected further up the TFF compilation pipelines. That is, in order to
  # reason about sufficiency of our ability to detect duplicates at this layer,
  # we would very much prefer to be operating in the subset of TFF effectively
  # representing local computation.

  def _remove_selection_from_block_holding_tuple(inner_comp):
    """Turns `(let ... in <...>)[i]` into `let ... in <...>[i]`'s element."""
    if (inner_comp.is_selection() and inner_comp.source.is_block() and
        inner_comp.source.result.is_struct()):
      if inner_comp.index is None:
        # Name-based selection: resolve the name to a positional index.
        names = [
            x[0] for x in anonymous_tuple.iter_elements(
                inner_comp.source.type_signature)
        ]
        index = names.index(inner_comp.name)
      else:
        index = inner_comp.index
      return building_blocks.Block(inner_comp.source.locals,
                                   inner_comp.source.result[index]), True
    return inner_comp, False

  comp, selections_removed = transformation_utils.transform_postorder(
      comp, _remove_selection_from_block_holding_tuple)
  transform_spec = tree_transformations.MergeTupleIntrinsics(comp, uri)
  dedupe_and_merger = RemoveDuplicatesAndApplyTransform(comp, transform_spec)
  comp, merged = transformation_utils.transform_postorder(
      comp, dedupe_and_merger.transform)
  # Report a modification from either pass; previously a change made only by
  # the first pass was silently reported as "unmodified".
  return comp, selections_removed or merged
コード例 #8
0
ファイル: transformations.py プロジェクト: uu0316/federated
def _check_parameters_for_tf_block_generation(block):
  """Validates that `block` has the shape needed to parse it into TF graphs.

  Args:
    block: The `building_blocks.Block` to validate.

  Raises:
    TypeError: If `block` is not a `building_blocks.Block`.
    ValueError: If any local is not a `building_blocks.Call` whose function is
      a `building_blocks.CompiledComputation`, or if `block.result` contains a
      node other than a `Reference`, `Selection` or `Tuple`.
  """
  py_typecheck.check_type(block, building_blocks.Block)
  for _, local_value in block.locals:
    is_called_compiled = (
        isinstance(local_value, building_blocks.Call) and
        isinstance(local_value.function, building_blocks.CompiledComputation))
    if not is_called_compiled:
      raise ValueError(
          'create_tensorflow_representing_block may only be called '
          'on a block whose local variables are all bound to '
          'called TensorFlow computations; encountered a local '
          'bound to {}'.format(local_value))

  allowed_result_types = (building_blocks.Reference, building_blocks.Selection,
                          building_blocks.Tuple)

  def _reject_disallowed_node(node):
    if isinstance(node, allowed_result_types):
      return node, False
    raise ValueError(
        'create_tensorflow_representing_block may only be called '
        'on a block whose result contains only Selections, '
        'Tuples and References; encountered the building block '
        '{}.'.format(node))

  transformation_utils.transform_postorder(block.result,
                                           _reject_disallowed_node)
コード例 #9
0
ファイル: tree_analysis.py プロジェクト: dingfan21/federated
def _visit_postorder(comp, function):
    """Calls `function` on every node of `comp`, children before parents.

    Args:
      comp: The `building_blocks.ComputationBuildingBlock` to walk.
      function: Callable invoked once per node; its return value is ignored.

    Raises:
      TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

    def _visit_without_modifying(node):
        function(node)
        return node, False

    transformation_utils.transform_postorder(comp, _visit_without_modifying)
コード例 #10
0
    def get_uri_for_all_called_intrinsics(comp):
        """Collects the URIs of called intrinsics in `comp` matching `uri`.

        `uri` is taken from the enclosing scope and passed as the filter to
        `building_block_analysis.is_called_intrinsic`.

        Returns:
          A `set` of `node.function.uri` values, one for each distinct
          matching called intrinsic found under `comp`.
        """
        found_uris = set()

        def _collect(node):
            if building_block_analysis.is_called_intrinsic(node, uri):
                found_uris.add(node.function.uri)
            return node, False

        transformation_utils.transform_postorder(comp, _collect)
        return found_uris
コード例 #11
0
def _visit_postorder(
        tree: building_blocks.ComputationBuildingBlock,
        function: Callable[[building_blocks.ComputationBuildingBlock], None]):
    """Invokes `function` on each building block of `tree`, children first.

    Args:
      tree: The root `building_blocks.ComputationBuildingBlock` to walk.
      function: Called once per visited building block; its result is ignored
        and the tree is left unmodified.

    Raises:
      TypeError: If `tree` is not a `building_blocks.ComputationBuildingBlock`.
    """
    py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)

    def _call_and_keep(node):
        function(node)
        return node, False

    transformation_utils.transform_postorder(tree, _call_and_keep)
コード例 #12
0
    def assertNoLambdasOrBlocks(self, comp):
        """Asserts `comp` contains no blocks and no directly-called lambdas.

        Note that uncalled `building_blocks.Lambda` nodes are allowed; only a
        `building_blocks.Call` whose function is a `Lambda` is rejected.

        Args:
          comp: The building block to inspect.

        Raises:
          AssertionError: On the first `building_blocks.Block`, or
            `building_blocks.Call` of a `building_blocks.Lambda`, met in
            post-order.
        """

        def _raise_on_disallowed(node):
            is_called_lambda = (
                isinstance(node, building_blocks.Call) and
                isinstance(node.function, building_blocks.Lambda))
            if is_called_lambda or isinstance(node, building_blocks.Block):
                raise AssertionError(
                    'Encountered disallowed computation: {}'.format(
                        node.compact_representation()))
            # Read-only check: report the node as unmodified, like every other
            # visitor here. Returning True falsely claimed a change per node.
            return node, False

        transformation_utils.transform_postorder(comp, _raise_on_disallowed)
コード例 #13
0
  def _inline_functions(comp):
    """Inlines block locals that are referenced with a functional type.

    Collects the name of every `Reference` in `comp` whose type signature is a
    function type, and passes those names as `variable_names` to
    `tree_transformations.inline_block_locals`.

    Returns:
      Whatever `tree_transformations.inline_block_locals` returns.
    """
    functional_ref_names = set()

    def _collect_functional_ref_names(node):
      if node.is_reference() and node.type_signature.is_function():
        functional_ref_names.add(node.name)
      return node, False

    transformation_utils.transform_postorder(comp,
                                             _collect_functional_ref_names)
    return tree_transformations.inline_block_locals(
        comp, variable_names=functional_ref_names)
コード例 #14
0
    def test_aggregate_with_selection_from_block_by_name_results_in_single_aggregate(
            self):
        # Two named selections ('a' and 'b') from the same block-wrapped tuple,
        # both holding the same federated reference, should merge into one
        # federated_aggregate.
        client_ints = computation_types.FederatedType(
            tf.int32, placement_literals.CLIENTS)
        data = building_blocks.Reference('a', client_ints)
        block_holding_tup = building_blocks.Block(
            [], building_blocks.Tuple([('a', data), ('b', data)]))
        selections = [
            building_blocks.Selection(source=block_holding_tup, name=name)
            for name in ('a', 'b')
        ]

        result = building_blocks.Data('aggregation_result', tf.int32)
        zero = building_blocks.Data('zero', tf.int32)
        accumulate = building_blocks.Lambda('accumulate_param',
                                            [tf.int32, tf.int32], result)
        merge = building_blocks.Lambda('merge_param', [tf.int32, tf.int32],
                                       result)
        report = building_blocks.Lambda('report_param', tf.int32, result)

        comp = building_blocks.Tuple(
            tuple(
                building_block_factory.create_federated_aggregate(
                    selection, zero, accumulate, merge, report)
                for selection in selections))

        deduped_and_merged_comp, deduped_modified = (
            transformations.dedupe_and_merge_tuple_intrinsics(
                comp, intrinsic_defs.FEDERATED_AGGREGATE.uri))
        self.assertTrue(deduped_modified)

        aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
        fed_agg = []

        def _find_called_federated_aggregate(node):
            is_called_aggregate = (
                isinstance(node, building_blocks.Call) and
                isinstance(node.function, building_blocks.Intrinsic) and
                node.function.uri == aggregate_uri)
            if is_called_aggregate:
                fed_agg.append(node.function)
            return node, False

        transformation_utils.transform_postorder(
            deduped_and_merged_comp, _find_called_federated_aggregate)
        self.assertLen(fed_agg, 1)
        merged_param = fed_agg[0].type_signature.parameter[0]
        self.assertEqual(merged_param.compact_representation(),
                         '{<int32>}@CLIENTS')
コード例 #15
0
    def test_ops_not_duplicated_in_resulting_tensorflow(self):

        def _construct_block_and_inlined_tuple(k):
            """Builds a chain of k+1 called identities two ways.

            Returns a block that binds the chain once and references it twice,
            and a tuple that inlines the chain twice.
            """
            concrete_int = building_block_factory.create_tensorflow_constant(
                computation_types.TensorType(tf.int32), 1)
            tf_id = building_block_factory.create_compiled_identity(
                computation_types.TensorType(tf.int32))
            called_tf_id = building_blocks.Call(tf_id, concrete_int)
            for _ in range(k):
                # Simulating large TF computation
                called_tf_id = building_blocks.Call(tf_id, called_tf_id)
            ref_to_call = building_blocks.Reference(
                'call', called_tf_id.type_signature)
            block = building_blocks.Block(
                [('call', called_tf_id)],
                building_blocks.Tuple([ref_to_call, ref_to_call]))
            inlined_tuple = building_blocks.Tuple([called_tf_id, called_tf_id])
            return block, inlined_tuple

        block_with_5_ids, inlined_tuple_with_5_ids = (
            _construct_block_and_inlined_tuple(5))
        block_with_10_ids, inlined_tuple_with_10_ids = (
            _construct_block_and_inlined_tuple(10))

        def _ops_in_tf_representing_block(block):
            tf_block, _ = transformations.create_tensorflow_representing_block(
                block)
            return tree_analysis.count_tensorflow_ops_under(tf_block)

        block_ops_with_5_ids = _ops_in_tf_representing_block(block_with_5_ids)
        block_ops_with_10_ids = _ops_in_tf_representing_block(
            block_with_10_ids)

        parser_callable = tree_to_cc_transformations.TFParser()

        def _ops_in_naive_tf(inlined_tuple):
            naive_tf, _ = transformation_utils.transform_postorder(
                inlined_tuple, parser_callable)
            return tree_analysis.count_tensorflow_ops_under(naive_tf)

        tuple_ops_with_5_ids = _ops_in_naive_tf(inlined_tuple_with_5_ids)
        tuple_ops_with_10_ids = _ops_in_naive_tf(inlined_tuple_with_10_ids)

        # Five extra identities add five ops when the chain is bound once...
        self.assertEqual((block_ops_with_10_ids - block_ops_with_5_ids) / 5, 1)
        # ...but ten ops when the chain is inlined twice.
        self.assertEqual((tuple_ops_with_10_ids - tuple_ops_with_5_ids) / 5, 2)
コード例 #16
0
 def _generate_simple_tensorflow(comp):
     """Naively generates TensorFlow to represent `comp`.

     Wraps the leaves of `comp` in called TF identities, then runs a
     `TFParser` over the result in post-order.

     Returns:
       The transformed building block (the modified flag is discarded).
     """
     parser = tree_to_cc_transformations.TFParser()
     with_identities, _ = (
         tree_transformations.insert_called_tf_identity_at_leaves(comp))
     parsed, _ = transformation_utils.transform_postorder(
         with_identities, parser)
     return parsed
コード例 #17
0
def _replace_selections(
    bb: building_blocks.ComputationBuildingBlock,
    ref_name: str,
    path_to_replacement: Dict[Tuple[int, ...],
                              building_blocks.ComputationBuildingBlock],
) -> building_blocks.ComputationBuildingBlock:
    """Replaces selection chains rooted at `ref_name` with new building blocks.

    This matching is deliberately narrow: it only rewrites AST fragments of
    exactly the form `ref_name[i][j][k]` (for path `(i, j, k)`). It will not
    see through, e.g., `let x = ref_name[i][j] in x[k]`.

    That is sufficient only because, when this runs, called lambdas have been
    replaced with blocks and blocks have been inlined, so no reference chains
    need to be traced back. Any reference that would eventually resolve to
    part of a lambda's parameter refers to the parameter directly, and
    selections from tuples have been collapsed. The remaining hazard is a
    selection made through a call to an opaque compiled computation, which is
    reported as an error.

    Args:
      bb: Instance of `building_blocks.ComputationBuildingBlock` in which to
        replace selections from reference `ref_name` whose path appears in
        `path_to_replacement`.
      ref_name: Name of the reference whose selections should be replaced.
      path_to_replacement: Maps a selection path to the building block that
        replaces it. Overlapping paths (one a prefix of another) are not
        supported.

    Returns:
      A possibly transformed version of `bb` with matching selections
      replaced.

    Raises:
      ValueError: If a compiled computation is called directly on a
        reference named `ref_name`.
    """

    def _selection_path(node):
        """Splits `node` into (innermost non-selection source, index path)."""
        indices = []
        while node.is_selection():
            indices.append(node.as_index())
            node = node.source
        # Selections are peeled outermost-first, so x[0][1] yields [1, 0].
        return node, tuple(reversed(indices))

    def _is_called_graph_on_ref(node):
        return (node.is_call() and node.function.is_compiled_computation()
                and node.argument is not None
                and node.argument.is_reference()
                and node.argument.name == ref_name)

    def _replace(inner_bb):
        root, path = _selection_path(inner_bb)
        if (root.is_reference() and root.name == ref_name
                and path in path_to_replacement):
            return path_to_replacement[path], True
        if _is_called_graph_on_ref(inner_bb):
            raise ValueError(
                'Encountered called graph on reference pattern in TFF '
                'AST; this means relying on pattern-matching when '
                'rebinding arguments may be insufficient. Ensure that '
                'arguments are rebound before decorating references '
                'with called identity graphs.')
        return inner_bb, False

    replaced, _ = transformation_utils.transform_postorder(bb, _replace)
    return replaced
コード例 #18
0
def _generate_simple_tensorflow(comp):
    """Naively generates TensorFlow to represent `comp`.

    Leaves are first wrapped in called TF identities; the tree is then
    rewritten in post-order by a `TFParser`.

    Returns:
      The transformed building block (the modified flag is discarded).
    """
    parser = tree_to_cc_transformations.TFParser()
    leaves_wrapped, _ = (
        tree_transformations.insert_called_tf_identity_at_leaves(comp))
    generated, _ = transformation_utils.transform_postorder(
        leaves_wrapped, parser)
    return generated
コード例 #19
0
    def test_constructs_broadcast_of_tuple_with_one_element(self):
        # A tuple holding the same called broadcast twice should be
        # deduplicated into a single broadcast of a one-element tuple.
        called_intrinsic = test_utils.create_dummy_called_federated_broadcast()
        comp = building_blocks.Tuple((called_intrinsic, called_intrinsic))
        broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri

        transformed_comp, modified = (
            compiler_transformations.dedupe_and_merge_tuple_intrinsics(
                comp, broadcast_uri))

        federated_broadcast = []

        def _collect_called_broadcasts(node):
            if building_block_analysis.is_called_intrinsic(node, broadcast_uri):
                federated_broadcast.append(node)
            return node, False

        transformation_utils.transform_postorder(transformed_comp,
                                                 _collect_called_broadcasts)

        self.assertTrue(modified)
        self.assertEqual(
            comp.compact_representation(),
            '<federated_broadcast(data),federated_broadcast(data)>')

        self.assertLen(federated_broadcast, 1)
        self.assertLen(federated_broadcast[0].type_signature.member, 1)
        expected_representation = (
            '(_var1 -> <\n'
            '  _var1[0],\n'
            '  _var1[0]\n'
            '>)((x -> <\n'
            '  x[0]\n'
            '>)((let\n'
            '  value=federated_broadcast(federated_apply(<\n'
            '    (arg -> <\n'
            '      arg\n'
            '    >),\n'
            '    <\n'
            '      data\n'
            '    >[0]\n'
            '  >))\n'
            ' in <\n'
            '  federated_map_all_equal(<\n'
            '    (arg -> arg[0]),\n'
            '    value\n'
            '  >)\n'
            '>)))')
        self.assertEqual(transformed_comp.formatted_representation(),
                         expected_representation)
コード例 #20
0
  def _extract_calls_and_blocks(comp):
    """Runs `tree_transformations.ExtractComputation` on every call in `comp`.

    Returns:
      The `(transformed_comp, modified)` pair produced by
      `transformation_utils.transform_postorder`.
    """
    extractor = tree_transformations.ExtractComputation(
        comp, lambda node: node.is_call())
    return transformation_utils.transform_postorder(comp, extractor.transform)
コード例 #21
0
def check_allowed_ops(
    comp: building_blocks.ComputationBuildingBlock,
    allowed_op_names: FrozenSet[str]
) -> Tuple[building_blocks.ComputationBuildingBlock, bool]:
  """Verifies every TensorFlow computation in `comp` uses only allowed ops.

  Args:
    comp: The building block to scan.
    allowed_op_names: Names of the TensorFlow ops that may appear.

  Returns:
    The `(comp, modified)` pair from a post-order traversal driven by
    `VerifyAllowedOps`.
  """
  verifier = VerifyAllowedOps(allowed_op_names)
  return transformation_utils.transform_postorder(comp, verifier.transform)
コード例 #22
0
def check_disallowed_ops(
    comp: building_blocks.ComputationBuildingBlock,
    disallowed_op_names: FrozenSet[str]
) -> Tuple[building_blocks.ComputationBuildingBlock, bool]:
  """Raises if any TensorFlow computation in `comp` uses a disallowed op.

  Args:
    comp: The building block to scan.
    disallowed_op_names: Names of the TensorFlow ops that must not appear.

  Returns:
    The `(comp, modified)` pair from a post-order traversal driven by
    `RaiseOnDisallowedOp`.
  """
  rejector = RaiseOnDisallowedOp(disallowed_op_names)
  return transformation_utils.transform_postorder(comp, rejector.transform)
コード例 #23
0
 def test_unwraps_block_with_empty_locals(self):
     # A block with no locals should collapse to its bare result.
     wrapped_data = building_blocks.Data('b', tf.int32)
     block_without_locals = building_blocks.Block([], wrapped_data)
     unwrapped, modified = transformation_utils.transform_postorder(
         block_without_locals, self._unused_block_remover.transform)
     self.assertTrue(modified)
     self.assertEqual(unwrapped.compact_representation(),
                      wrapped_data.compact_representation())
コード例 #24
0
ファイル: compiler.py プロジェクト: tensorflow/federated
def concatenate_function_outputs(first_function, second_function):
    """Constructs a new function concatenating the outputs of its arguments.

    Assumes that `first_function` and `second_function` already have unique
    names, and have declared parameters of the same type. The constructed
    function binds its parameter to the parameters of both `first_function`
    and `second_function`. It returns a two-element struct holding the result
    of `first_function` followed by the result of `second_function`.

    Args:
      first_function: Instance of `building_blocks.Lambda` whose result we wish
        to concatenate with the result of `second_function`.
      second_function: Instance of `building_blocks.Lambda` whose result we
        wish to concatenate with the result of `first_function`.

    Returns:
      A new instance of `building_blocks.Lambda`, with reference names passed
      through `tree_transformations.uniquify_reference_names`, representing the
      computation described above.

    Raises:
      TypeError: If the arguments are not instances of `building_blocks.Lambda`,
        or declare parameters of different types.
    """

    py_typecheck.check_type(first_function, building_blocks.Lambda)
    py_typecheck.check_type(second_function, building_blocks.Lambda)
    tree_analysis.check_has_unique_names(first_function)
    tree_analysis.check_has_unique_names(second_function)

    first_parameter_type = first_function.parameter_type
    second_parameter_type = second_function.parameter_type
    if first_parameter_type != second_parameter_type:
        # Report the parameter types themselves, which is what the message
        # describes; previously the full function signatures were printed.
        raise TypeError(
            'Must pass two functions which declare the same parameter '
            'type to `concatenate_function_outputs`; you have passed '
            'one function which declared a parameter of type {}, and '
            'another which declares a parameter of type {}'.format(
                first_parameter_type, second_parameter_type))

    first_parameter_name = first_function.parameter_name
    second_parameter_name = second_function.parameter_name

    # NOTE(review): uniqueness is only checked within each function, not
    # across both. If `first_function` internally binds a name equal to
    # `second_parameter_name`, a renamed reference below could be captured by
    # that binding. TODO confirm callers rule this out.
    def _rename_first_function_arg(inner_comp):
        """Points references to the first parameter at the second's name."""
        if inner_comp.is_reference() and inner_comp.name == first_parameter_name:
            if inner_comp.type_signature != second_parameter_type:
                raise AssertionError('{}, {}'.format(inner_comp.type_signature,
                                                     second_parameter_type))
            return building_blocks.Reference(second_parameter_name,
                                             inner_comp.type_signature), True
        return inner_comp, False

    first_function, _ = transformation_utils.transform_postorder(
        first_function, _rename_first_function_arg)

    concatenated_function = building_blocks.Lambda(
        second_parameter_name, second_parameter_type,
        building_blocks.Struct([first_function.result,
                                second_function.result]))

    renamed, _ = tree_transformations.uniquify_reference_names(
        concatenated_function)

    return renamed
コード例 #25
0
 def test_leaves_single_used_reference(self):
     # The only local is referenced by the result, so nothing may be removed
     # and the remover must report no change.
     bound_data = building_blocks.Data('a', tf.int32)
     result_ref = building_blocks.Reference('x', tf.int32)
     blk = building_blocks.Block([('x', bound_data)], result_ref)
     untouched_blk, modified = transformation_utils.transform_postorder(
         blk, self._unused_block_remover.transform)
     self.assertFalse(modified)
     self.assertEqual(untouched_blk.compact_representation(),
                      blk.compact_representation())
コード例 #26
0
  def test_parameters_are_mapped_together(self):
    # After concatenation, every reference must point at the single shared
    # parameter of the concatenated lambda.
    x_reference = building_blocks.Reference('x', tf.int32)
    x_lambda = building_blocks.Lambda('x', tf.int32, x_reference)
    y_reference = building_blocks.Reference('y', tf.int32)
    y_lambda = building_blocks.Lambda('y', tf.int32, y_reference)
    concatenated = transformations.concatenate_function_outputs(
        x_lambda, y_lambda)
    parameter_name = concatenated.parameter_name

    def _raise_on_other_name_reference(comp):
      if (isinstance(comp, building_blocks.Reference) and
          comp.name != parameter_name):
        raise ValueError(
            'Expected every reference to name the shared parameter {!r}; '
            'found a reference to {!r}.'.format(parameter_name, comp.name))
      # Read-only check: report the node as unmodified rather than claiming a
      # change at every node.
      return comp, False

    tree_analysis.check_has_unique_names(concatenated)
    transformation_utils.transform_postorder(concatenated,
                                             _raise_on_other_name_reference)
コード例 #27
0
def parse_tff_to_tf(comp):
  """Normalizes `comp` and parses it into TensorFlow with a `TFParser`.

  The passes applied, in order: insert called TF identities at leaves, replace
  called lambdas with blocks, inline block locals, replace selections from
  tuples with the selected element, and finally run the `TFParser`.

  Returns:
    `(new_comp, transformed)`, where `transformed` is the modified flag from
    the `TFParser` pass only. The flags from the preparatory passes are
    discarded.
  """
  comp, _ = tree_transformations.insert_called_tf_identity_at_leaves(comp)
  parser_callable = tree_to_cc_transformations.TFParser()
  normalizing_passes = (
      tree_transformations.replace_called_lambda_with_block,
      tree_transformations.inline_block_locals,
      tree_transformations.replace_selection_from_tuple_with_element,
  )
  for normalize in normalizing_passes:
    comp, _ = normalize(comp)
  parsed, transformed = transformation_utils.transform_postorder(
      comp, parser_callable)
  return parsed, transformed
コード例 #28
0
 def test_removes_nested_blocks_with_unused_reference(self):
     # Neither the outer local `y` nor the inner local `x` is referenced, so
     # both blocks should collapse down to the bare data `b`.
     input_data = building_blocks.Data('b', tf.int32)
     inner_blk = building_blocks.Block(
         [('x', building_blocks.Data('a', tf.int32))], input_data)
     outer_blk = building_blocks.Block([('y', input_data)], inner_blk)
     collapsed, modified = transformation_utils.transform_postorder(
         outer_blk, self._unused_block_remover.transform)
     self.assertTrue(modified)
     self.assertEqual(collapsed.compact_representation(),
                      input_data.compact_representation())
コード例 #29
0
 def test_leaves_lone_referenced_local(self):
     # Only `y` is referenced by the result, so the unused `x` binding should
     # be dropped while `y` survives.
     block_locals = [('x', building_blocks.Data('a', tf.int32)),
                     ('y', building_blocks.Data('b', tf.int32))]
     result_ref = building_blocks.Reference('y', tf.int32)
     blk = building_blocks.Block(block_locals, result_ref)
     pruned_blk, modified = transformation_utils.transform_postorder(
         blk, self._unused_block_remover.transform)
     self.assertTrue(modified)
     self.assertEqual(pruned_blk.compact_representation(), '(let y=b in y)')
コード例 #30
0
def count(comp, predicate=None):
    """Counts the computations in `comp` that satisfy `predicate`.

    Args:
      comp: The `building_blocks.ComputationBuildingBlock` to walk.
      predicate: Optional callable that takes a computation and returns a
        boolean. When `None`, every computation is counted.

    Returns:
      The number of nodes visited in a post-order traversal of `comp` for
      which `predicate` is truthy (or the total node count when `predicate`
      is `None`).

    Raises:
      TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    matches = 0

    def _tally(node):
        nonlocal matches
        if predicate is None or predicate(node):
            matches += 1
        return node, False

    transformation_utils.transform_postorder(comp, _tally)
    return matches