예제 #1
0
 def test_nested_lambda_block_overwrite_scope_snapshot(self):
     """Checks per-scope reference counts when `x` is rebound at every level."""
     bb = computation_building_blocks
     identity_fn = bb.Lambda('x', tf.int32, bb.Reference('x', tf.int32))
     inner_let = bb.Block(
         [('x', bb.Reference('block_in', tf.int32))],
         bb.Call(identity_fn, bb.Reference('x', tf.int32)))
     outer_fn = bb.Lambda('block_in', tf.int32, inner_let)
     outer_let = bb.Block(
         [('x', bb.Data('test_data', tf.int32))],
         bb.Call(outer_fn, bb.Reference('x', tf.int32)))
     snapshot = transformations.scope_count_snapshot(outer_let)
     self.assertEqual(
         str(outer_let),
         '(let x=test_data in (block_in -> (let x=block_in in (x -> x)(x)))(x))'
     )
     # One entry per binding scope: the two lambdas and the two blocks.
     expected_counts = {
         str(identity_fn): {'x': 1},
         str(inner_let): {'x': 1},
         str(outer_fn): {'block_in': 1},
         str(outer_let): {'x': 1},
     }
     self.assertLen(snapshot, 4)
     for scope_repr, counts in expected_counts.items():
         self.assertEqual(snapshot[scope_repr], counts)
예제 #2
0
 def test_simple_block_snapshot(self):
     """Checks that each block's snapshot only counts its own locals."""
     bb = computation_building_blocks
     used1_ref = bb.Reference('used1', tf.int32)
     used2_ref = bb.Reference('used2', tf.int32)
     inner = bb.Block([('x', used1_ref)],
                      bb.Reference('x', used1_ref.type_signature))
     outer = bb.Block([('used1', used2_ref)], inner)
     self.assertEqual(str(outer), '(let used1=used2 in (let x=used1 in x))')
     snapshot = transformations.scope_count_snapshot(outer)
     inner_counts = snapshot[str(inner)]
     outer_counts = snapshot[str(outer)]
     self.assertEqual(inner_counts['x'], 1)
     self.assertEqual(outer_counts['used1'], 1)
     self.assertIsNone(outer_counts.get('x'))
예제 #3
0
 def test_conflicting_name_resolved_inlining(self):
     """The inner `x` shadows the outer one, so the inner binding is inlined."""
     bb = computation_building_blocks
     decoy = bb.Reference('redherring', tf.int32)
     used = bb.Reference('used', tf.int32)
     inner = bb.Block([('x', used)], bb.Reference('x', used.type_signature))
     outer = bb.Block([('x', decoy)], inner)
     self.assertEqual(str(outer), '(let x=redherring in (let x=used in x))')
     result = transformations.inline_blocks_with_n_referenced_locals(outer)
     self.assertEqual(str(result), '(let  in (let  in used))')
예제 #4
0
    def test_with_block(self):
        """Evaluates a lambda whose body is a block over its tuple argument.

        The lambda binds `f` and `x` from its argument and computes f(f(x));
        with `f = add_one` and `x = 10` the result is 12.
        """
        executor = lambda_executor.LambdaExecutor(
            eager_executor.EagerExecutor())
        event_loop = asyncio.get_event_loop()

        def run(coro):
            return event_loop.run_until_complete(coro)

        fn_type = computation_types.FunctionType(tf.int32, tf.int32)
        arg_ref = computation_building_blocks.Reference(
            'a',
            computation_types.NamedTupleType([('f', fn_type), ('x', tf.int32)]))
        body = computation_building_blocks.Block(
            [('f', computation_building_blocks.Selection(arg_ref, name='f')),
             ('x', computation_building_blocks.Selection(arg_ref, name='x'))],
            computation_building_blocks.Call(
                computation_building_blocks.Reference('f', fn_type),
                computation_building_blocks.Call(
                    computation_building_blocks.Reference('f', fn_type),
                    computation_building_blocks.Reference('x', tf.int32))))
        comp = computation_building_blocks.Lambda(arg_ref.name,
                                                  arg_ref.type_signature, body)

        @computations.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        fn_value = run(executor.create_value(comp.proto, comp.type_signature))
        add_one_value = run(executor.create_value(add_one))
        ten_value = run(executor.create_value(10, tf.int32))
        arg_value = run(
            executor.create_tuple(
                anonymous_tuple.AnonymousTuple([('f', add_one_value),
                                                ('x', ten_value)])))
        call_value = run(executor.create_call(fn_value, arg_value))
        result = run(call_value.compute())
        self.assertEqual(result.numpy(), 12)
  def test_execute_with_block(self):
    """Rebinding `x` in one block: 10 followed by three increments is 13."""
    bb = computation_building_blocks

    def _to_building_block(fn):
      return bb.ComputationBuildingBlock.from_proto(
          computation_impl.ComputationImpl.get_proto(fn))

    add_one = _to_building_block(
        computations.tf_computation(lambda x: x + 1, tf.int32))
    make_10 = _to_building_block(
        computations.tf_computation(lambda: tf.constant(10)))

    # Each later local rebinds `x` to add_one applied to the previous `x`.
    bindings = [('x', bb.Call(make_10))]
    bindings.extend(
        ('x', bb.Call(add_one, bb.Reference('x', tf.int32))) for _ in range(3))
    make_13 = bb.Block(bindings, bb.Reference('x', tf.int32))

    make_13_computation = computation_impl.ComputationImpl(
        make_13.proto, context_stack_impl.context_stack)

    self.assertEqual(make_13_computation(), 13)
 def test_basic_functionality_of_block_class(self):
     """Checks type, locals, reprs and proto serialization of a `Block`."""
     x = computation_building_blocks.Block([
         ('x',
          computation_building_blocks.Reference('arg',
                                                (tf.int32, tf.int32))),
         ('y',
          computation_building_blocks.Selection(
              computation_building_blocks.Reference('x',
                                                    (tf.int32, tf.int32)),
              index=0))
     ], computation_building_blocks.Reference('y', tf.int32))
     self.assertEqual(str(x.type_signature), 'int32')
     self.assertEqual([(k, v.tff_repr) for k, v in x.locals],
                      [('x', 'arg'), ('y', 'x[0]')])
     self.assertEqual(x.result.tff_repr, 'y')
     self.assertEqual(
         repr(x), 'Block([(\'x\', Reference(\'arg\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
         '(\'y\', Selection(Reference(\'x\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
         'index=0))], '
         'Reference(\'y\', TensorType(tf.int32)))')
     self.assertEqual(x.tff_repr, '(let x=arg,y=x[0] in y)')
     x_proto = x.proto
     self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                      x.type_signature)
     self.assertEqual(x_proto.WhichOneof('computation'), 'block')
     self.assertEqual(str(x_proto.block.result), str(x.result.proto))
     # Check the counts first so that `zip` cannot silently skip locals that
     # exist on only one side.
     self.assertLen(x_proto.block.local, len(x.locals))
     for loc_proto, (loc_name, loc_value) in zip(x_proto.block.local,
                                                 x.locals):
         self.assertEqual(loc_proto.name, loc_name)
         self.assertEqual(str(loc_proto.value), str(loc_value.proto))
     # The round-trip concerns the whole block, so it runs exactly once,
     # after the per-local checks (not once per local).
     self._serialize_deserialize_roundtrip_test(x)
예제 #7
0
 def _traverse_block(comp, transform, context_tree, identifier_seq):
     """Helper function holding traversal logic for block nodes.

     Draws an id for this block from `identifier_seq`. Then it transforms each
     local in order and ingests the resulting binding into `context_tree`. The
     first local uses `MutationMode.CHILD` and this block's id; every later
     local uses `MutationMode.SIBLING`. Next it transforms the result, rebuilds
     the block and passes it to `transform` together with `context_tree`. If
     the block had any locals, it calls `context_tree.move_to_parent_context()`
     before returning the transformed block.
     """
     block_id = six.next(identifier_seq)
     new_locals = []
     for position, (local_name, local_comp) in enumerate(comp.locals):
         new_local = _transform_postorder_with_symbol_bindings_switch(
             local_comp, transform, context_tree, identifier_seq)
         new_locals.append((local_name, new_local))
         if position == 0:
             context_tree.ingest_variable_binding(
                 name=local_name,
                 value=new_local,
                 mode=MutationMode.CHILD,
                 comp_id=block_id)
         else:
             context_tree.ingest_variable_binding(
                 name=local_name, value=new_local, mode=MutationMode.SIBLING)
     new_result = _transform_postorder_with_symbol_bindings_switch(
         comp.result, transform, context_tree, identifier_seq)
     rebuilt = computation_building_blocks.Block(new_locals, new_result)
     result_comp = transform(rebuilt, context_tree)
     if comp.locals:
         context_tree.move_to_parent_context()
     return result_comp
예제 #8
0
 def _transform(comp):
     """Rewrites `comp` as a block binding its function's parameter.

     When `_should_transform(comp)` holds, this returns
     `Block([(comp.function.parameter_name, comp.argument)],
     comp.function.result)` paired with `True`. Otherwise it returns `comp`
     unchanged paired with `False`.
     """
     if _should_transform(comp):
         fn = comp.function
         rewritten = computation_building_blocks.Block(
             [(fn.parameter_name, comp.argument)], fn.result)
         return rewritten, True
     return comp, False
예제 #9
0
  def __call__(self, comp):
    """Counts references to locals under Block and performs inlining.

    If the `comp` argument is a `computation_building_blocks.Block`, `__call__`
    selects the locals on which to perform inlining based on the threshold
    defined in `inlining_threshold` and the snapshot of the calling AST
    before any transformations are executed, stored as `counts`, before
    executing the inlining itself.

    `self.idx` is advanced on every call, whether or not `comp` is a Block, so
    that it stays aligned with `self.initial_comp_names`.

    Args:
      comp: The `computation_building_blocks.ComputationBuildingBlock` to be
        checked for the possibility of inlining.

    Returns:
      comp: A transformed version of `comp`, with locals of any of its
        `computation_building_blocks.Block`s which are referenced
        `inlining_threshold` or fewer times replaced with their
        associated values. All local declarations no longer referenced
        in the body are removed.
    """
    self.idx += 1
    if not isinstance(comp, computation_building_blocks.Block):
      return comp
    bound_dict = self.counts[self.initial_comp_names[self.idx]]
    values_to_replace = {
        name for name, count in bound_dict.items()
        if count <= self.inlining_threshold
    }
    names_and_values = {
        name: value for name, value in comp.locals if name in values_to_replace
    }

    def _execute_inlining_from_bound_dict(inner_comp):
      """Uses `dict` bound to calling comp to inline as appropriate.

      Args:
        inner_comp: The `computation_building_blocks.ComputationBuildingBlock`
          to potentially inline.

      Returns:
        `computation_building_blocks.ComputationBuildingBlock`, `inner_comp`
        unchanged if `inner_comp` is not a
        `computation_building_blocks.Reference` whose name appears in
        `names_and_values`; otherwise the appropriate local definition.
      """
      # Test membership, not the truthiness of the bound value. A falsy value
      # (e.g. an empty `Tuple`, if it defines `__len__`) would otherwise be
      # skipped here even though its local is dropped below, which would
      # leave a dangling reference.
      if (isinstance(inner_comp, computation_building_blocks.Reference) and
          inner_comp.name in names_and_values):
        replacement = names_and_values[inner_comp.name]
        py_typecheck.check_type(
            replacement, computation_building_blocks.ComputationBuildingBlock)
        return replacement
      return inner_comp

    remaining_locals = [(name, val)
                        for name, val in comp.locals
                        if name not in values_to_replace]
    return computation_building_blocks.Block(
        remaining_locals,
        transform_postorder(comp.result, _execute_inlining_from_bound_dict))
예제 #10
0
 def test_conflicting_nested_name_inlining(self):
     """Inlining respects an inner block that rebinds `x` to `y`."""
     bb = computation_building_blocks
     shadowing_block = bb.Block([('x', bb.Reference('y', tf.int32))],
                                bb.Reference('x', tf.int32))
     body = bb.Tuple([bb.Reference('x', tf.int32), shadowing_block])
     outer_block = bb.Block([('x', bb.Reference('used', tf.int32)),
                             ('y', bb.Reference('used1', tf.int32))], body)
     self.assertEqual(str(outer_block),
                      '(let x=used,y=used1 in <x,(let x=y in x)>)')
     inlined = transformations.inline_blocks_with_n_referenced_locals(
         outer_block)
     self.assertEqual(str(inlined), '(let  in <used,(let  in used1)>)')
예제 #11
0
 def test_propogates_dependence_up_through_block_locals(self):
   """A block whose local is the dummy intrinsic counts as consuming it."""
   intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                     tf.int32)
   block = computation_building_blocks.Block(
       [('x', intrinsic)],
       computation_building_blocks.Reference('int', tf.int32))
   consumers = tree_analysis.extract_nodes_consuming(
       block, dummy_intrinsic_predicate)
   self.assertIn(block, consumers)
예제 #12
0
 def test_scope_snapshot_block_overwrite(self):
     """Each block's snapshot counts only its own locals; tuples get none."""
     bb = computation_building_blocks
     inner_block = bb.Block([('x', bb.Reference('y', tf.int32))],
                            bb.Reference('x', tf.int32))
     pair = bb.Tuple([bb.Reference('x', tf.int32), inner_block])
     outer_block = bb.Block([('x', bb.Reference('used', tf.int32)),
                             ('y', bb.Reference('used1', tf.int32))], pair)
     self.assertEqual(str(outer_block),
                      '(let x=used,y=used1 in <x,(let x=y in x)>)')
     snapshot = transformations.scope_count_snapshot(outer_block)
     self.assertEqual(snapshot[str(inner_block)], {'x': 1})
     self.assertEqual(snapshot[str(outer_block)], {'x': 1, 'y': 1})
     self.assertIsNone(snapshot.get(str(pair)))
def _create_chain_zipped_values(value):
    r"""Creates a chain of called federated zip with two values.

                Block--------
               /             \
  [value=Tuple]               Call
         |                   /    \
         [Comp1,    Intrinsic      Tuple
          Comp2,                   |
          ...]                     [Call,  Sel(n)]
                                   /    \        \
                          Intrinsic      Tuple    Ref(value)
                                         |
                                         [Sel(0),       Sel(1)]
                                                \             \
                                                 Ref(value)    Ref(value)

  The chain is a left fold. The first step wraps `Sel(0)` and `Sel(1)` in a
  `Tuple` and passes it to `_create_zip_two_values`. Each later step does the
  same with the previous result and `Sel(i)`.

  NOTE: This function is intended to be used in conjunction with
  `_create_fn_to_append_chain_zipped_values` and will drop the tuple names. The
  names will be added back to the resulting computation when the zipped values
  are mapped to a function that flattens the chain. This nested zip -> flatten
  structure must be used since length of a named tuple type in the TFF type
  system is an element of the type proper. That is, a named tuple type of
  length 2 is a different type than a named tuple type of length 3, they are
  not simply items with the same type and different values, as would be the
  case if you were thinking of these as Python `list`s. It may be better to
  think of named tuple types in TFF as more like `struct`s.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements.

  Returns:
    A `computation_building_blocks.Block` that binds `value` to the name
    `'value'` and whose result is the outermost call of the chain.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain at least two elements.
  """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length < 2:
        # Report the element count; formatting the element list itself would
        # produce a message like "received [('a', int32)] elements".
        raise ValueError(
            'Expected a value with at least two elements, received {} elements.'
            .format(length))
    ref = computation_building_blocks.Reference('value', value.type_signature)
    symbols = ((ref.name, value), )
    result = computation_building_blocks.Selection(ref, index=0)
    for i in range(1, length):
        sel = computation_building_blocks.Selection(ref, index=i)
        values = computation_building_blocks.Tuple((result, sel))
        result = _create_zip_two_values(values)
    return computation_building_blocks.Block(symbols, result)
예제 #14
0
 def test_simple_block_inlining(self):
     """A local referenced once is replaced by its value and dropped."""
     data = computation_building_blocks.Data('test_data', tf.int32)
     block = computation_building_blocks.Block(
         [('test_x', data)],
         computation_building_blocks.Reference('test_x', data.type_signature))
     self.assertEqual(str(block), '(let test_x=test_data in test_x)')
     inlined = transformations.inline_blocks_with_n_referenced_locals(block)
     self.assertEqual(str(inlined), '(let  in test_data)')
예제 #15
0
def create_federated_unzip(value):
    r"""Creates a tuple of called federated maps or applies.

                Block
               /     \
  [value=Comp]        Tuple
                      |
                      [Call,                        Call, ...]
                      /    \                       /    \
             Intrinsic      Tuple         Intrinsic      Tuple
                            |                            |
                [Lambda(arg), Ref(value)]    [Lambda(arg), Ref(value)]
                            \                            \
                             Sel(0)                       Sel(1)
                                   \                            \
                                    Ref(arg)                     Ref(arg)

  This function returns a tuple of federated values given a `value` with a
  federated tuple type signature. Element `i` of the result is
  `create_federated_map_or_apply` of the selector lambda `arg -> arg[i]` over a
  reference to `value`, and it keeps the name of member element `i`.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` whose
      `type_signature` has a `member` (presumably a
      `computation_types.FederatedType`) that is a
      `computation_types.NamedTupleType` containing at least one element.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain any elements.
  """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(
        value.type_signature.member)
    length = len(named_type_signatures)
    if length == 0:
        # This is the unzip builder, so the message names federated_unzip.
        raise ValueError(
            'federated_unzip is only supported on non-empty tuples.')
    value_ref = computation_building_blocks.Reference('value',
                                                      value.type_signature)
    elements = []
    # The selector lambdas are typed by the (unplaced) member tuple type.
    fn_ref = computation_building_blocks.Reference('arg',
                                                   named_type_signatures)
    for index, (name, _) in enumerate(named_type_signatures):
        sel = computation_building_blocks.Selection(fn_ref, index=index)
        fn = computation_building_blocks.Lambda(fn_ref.name,
                                                fn_ref.type_signature, sel)
        intrinsic = create_federated_map_or_apply(fn, value_ref)
        elements.append((name, intrinsic))
    result = computation_building_blocks.Tuple(elements)
    symbols = ((value_ref.name, value), )
    return computation_building_blocks.Block(symbols, result)
예제 #16
0
def _create_chain_zipped_values(value):
    r"""Creates a chain of called federated zip with two values.

                Block--------
               /             \
  [value=Tuple]               Call
         |                   /    \
         [Comp1,    Intrinsic      Tuple
          Comp2,                   |
          ...]                     [Call,  Sel(n)]
                                   /    \        \
                          Intrinsic      Tuple    Ref(value)
                                         |
                                         [Sel(0),       Sel(1)]
                                                \             \
                                                 Ref(value)    Ref(value)

  The chain is a left fold. The first step passes the tuple
  `<name_0=Sel(0), name_1=Sel(1)>` to `_create_zip_two_values`. Each later step
  passes `<previous result, name_i=Sel(i)>`.

  NOTE: This function is intended to be used in conjunction with
  `_create_fn_to_append_chain_zipped_values`. It attaches each element name of
  `value` to the corresponding selection. Whether those names survive in the
  zipped result depends on `_create_zip_two_values`.
  NOTE(review): confirm the flattening function expects these names.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements.

  Returns:
    A `computation_building_blocks.Block` that binds `value` to the name
    `'value'` and whose result is the outermost call of the chain.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain at least two elements.
  """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length < 2:
        # Report the element count rather than the element list itself.
        raise ValueError(
            'Expected a value with at least two elements, received {} elements.'
            .format(length))
    ref = computation_building_blocks.Reference('value', value.type_signature)
    symbols = ((ref.name, value), )
    first_name, _ = named_type_signatures[0]
    # Only the first step pairs two named selections; afterwards the left
    # operand is the (unnamed) result of the previous zip.
    result = (first_name, computation_building_blocks.Selection(ref, index=0))
    for i in range(1, length):
        name, _ = named_type_signatures[i]
        sel = computation_building_blocks.Selection(ref, index=i)
        values = computation_building_blocks.Tuple((result, (name, sel)))
        result = _create_zip_two_values(values)
    return computation_building_blocks.Block(symbols, result)
예제 #17
0
 def test_inline_conflicting_locals(self):
     """A local that shadows the lambda parameter is inlined correctly."""
     bb = computation_building_blocks
     outer_arg = bb.Reference('arg', [tf.int32, tf.int32])
     shadowing_block = bb.Block([('arg', bb.Selection(outer_arg, index=0))],
                                bb.Reference('arg', tf.int32))
     lam = bb.Lambda('arg', outer_arg.type_signature, shadowing_block)
     self.assertEqual(str(lam), '(arg -> (let arg=arg[0] in arg))')
     inlined = transformations.inline_blocks_with_n_referenced_locals(lam)
     self.assertEqual(str(inlined), '(arg -> (let  in arg[0]))')
예제 #18
0
 def _extract_from_lambda(comp):
     """Returns a new computation with all intrinsics extracted.

     Case 1: `comp.result` is a called intrinsic. The call is bound to a fresh
     name drawn from `name_generator`. If the call has no unbound reference to
     the lambda parameter, the binding is hoisted into a block that wraps the
     lambda. Otherwise the binding is placed in a block inside the lambda.

     Case 2: any other result. The result is assumed to be a block (its
     `.locals` and `.result` are read). Locals that reference neither the
     lambda parameter nor any local retained so far are hoisted into a block
     wrapping the lambda. That block is then passed to `_extract_from_block`.
     """
     param_name = comp.parameter_name
     param_type = comp.parameter_type
     if _is_called_intrinsic(comp.result):
         called_intrinsic = comp.result
         var_name = six.next(name_generator)
         bindings = ((var_name, called_intrinsic), )
         var_ref = computation_building_blocks.Reference(
             var_name, called_intrinsic.type_signature)
         if _contains_unbound_reference(comp.result, param_name):
             inner = computation_building_blocks.Block(bindings, var_ref)
             return computation_building_blocks.Lambda(param_name, param_type,
                                                       inner)
         fn = computation_building_blocks.Lambda(param_name, param_type,
                                                 var_ref)
         return computation_building_blocks.Block(bindings, fn)

     inner_block = comp.result
     hoisted = []
     kept = []
     kept_names = []
     for local_name, local_value in inner_block.locals:
         # A local must stay inside the lambda if it uses the parameter or any
         # local that has already been kept inside.
         stays_inside = (
             _contains_unbound_reference(local_value, param_name) or
             _contains_unbound_reference(local_value, list(kept_names)))
         if stays_inside:
             kept.append((local_name, local_value))
             kept_names.append(local_name)
         else:
             hoisted.append((local_name, local_value))
     if kept:
         new_result = computation_building_blocks.Block(kept,
                                                        inner_block.result)
     else:
         new_result = inner_block.result
     fn = computation_building_blocks.Lambda(param_name, param_type,
                                             new_result)
     return _extract_from_block(computation_building_blocks.Block(hoisted, fn))
예제 #19
0
    def test_replace_called_lambda_does_not_replace_separated_called_lambda(
            self):
        """A lambda wrapped in a block before being called is left alone."""
        arg = computation_building_blocks.Reference('arg', tf.int32)
        wrapped_lambda = computation_building_blocks.Block(
            [], _create_lambda_to_identity(arg.type_signature))
        comp = computation_building_blocks.Call(wrapped_lambda, arg)

        transformed_comp = transformations.replace_called_lambda_with_block(
            comp)

        self.assertEqual(str(transformed_comp), str(comp))
        self.assertEqual(str(transformed_comp), '(let  in (arg -> arg))(arg)')
예제 #20
0
 def _extract_from_block(comp):
   """Returns a new computation with all intrinsics extracted.

   Three cases, based on `comp.result`:
   - a called intrinsic: bind it to a fresh name from `name_generator`, append
     that binding after the existing locals, and return a reference to it;
   - a nested block: merge its locals after the outer locals and use its
     result directly;
   - anything else: flatten any local whose value is a block into that block's
     locals followed by `(name, block.result)`.
   `comp` itself is never modified.
   """
   if _is_called_intrinsic(comp.result):
     called_intrinsic = comp.result
     name = six.next(name_generator)
     # Copy before appending: `comp.locals` may hand back the block's own
     # list, and appending to it would silently mutate `comp`.
     variables = list(comp.locals)
     variables.append((name, called_intrinsic))
     result = computation_building_blocks.Reference(
         name, called_intrinsic.type_signature)
     return computation_building_blocks.Block(variables, result)
   elif isinstance(comp.result, computation_building_blocks.Block):
     return computation_building_blocks.Block(
         list(comp.locals) + list(comp.result.locals), comp.result.result)
   else:
     variables = []
     for name, variable in comp.locals:
       if isinstance(variable, computation_building_blocks.Block):
         variables.extend(variable.locals)
         variables.append((name, variable.result))
       else:
         variables.append((name, variable))
     return computation_building_blocks.Block(variables, comp.result)
예제 #21
0
 def test_multiple_inline_for_nested_block(self):
     """Inlining chains through nested blocks and drops unreferenced locals."""
     bb = computation_building_blocks
     used1 = bb.Reference('used1', tf.int32)
     used2 = bb.Reference('used2', tf.int32)
     # `x` -> `used1` -> `used2`: both locals collapse into the result.
     nested = bb.Block([('used1', used2)],
                       bb.Block([('x', used1)],
                                bb.Reference('x', used1.type_signature)))
     inlined = transformations.inline_blocks_with_n_referenced_locals(nested)
     self.assertEqual(str(nested), '(let used1=used2 in (let x=used1 in x))')
     self.assertEqual(str(inlined), '(let  in (let  in used2))')
     # The inner result already names `used1`. The expected output drops the
     # unreferenced `x` but keeps `used1=used2`.
     hand_inlined = bb.Block([('used1', used2)],
                             bb.Block([('x', used1)], used1))
     self.assertEqual(str(hand_inlined),
                      '(let used1=used2 in (let x=used1 in used1))')
     inlined_noop = transformations.inline_blocks_with_n_referenced_locals(
         hand_inlined)
     self.assertEqual(str(inlined_noop),
                      '(let used1=used2 in (let  in used1))')
예제 #22
0
def create_dummy_block(comp, variable_name, variable_type=tf.int32):
  r"""Returns a block that binds a dummy `Data` value and evaluates to `comp`.

           Block
          /     \
  [x=data]       Comp

  This is not an identity block: the result is `comp` itself, not a reference
  to the bound variable.

  Args:
    comp: The computation to use as the result.
    variable_name: The name of the variable.
    variable_type: The type of the variable.

  Returns:
    A `computation_building_blocks.Block` with the single local
    `(variable_name, Data('data', variable_type))` and result `comp`.
  """
  data = computation_building_blocks.Data('data', variable_type)
  return computation_building_blocks.Block([(variable_name, data)], comp)
예제 #23
0
  def test_propogates_dependence_into_binding_to_reference(self):
    """A reference bound to the federated zero counts as consuming it."""
    zero_uri = intrinsic_defs.GENERIC_ZERO.uri
    fed_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    ref_to_x = computation_building_blocks.Reference('x', fed_type)
    federated_zero = computation_building_blocks.Intrinsic(zero_uri, fed_type)

    def federated_zero_predicate(x):
      if not isinstance(x, computation_building_blocks.Intrinsic):
        return False
      return x.uri == zero_uri

    block = computation_building_blocks.Block([('x', federated_zero)],
                                              ref_to_x)
    consumers = tree_analysis.extract_nodes_consuming(
        block, federated_zero_predicate)
    self.assertIn(ref_to_x, consumers)
예제 #24
0
def create_computation_appending(comp1, comp2):
    r"""Returns a block whose result is `comp1`'s elements followed by `comp2`.

                Block
               /     \
  [comps=Tuple]       Tuple
         |            |
    [Comp, Comp]      [Sel(0), ...,  Sel(0),   Sel(1)]
                             \             \         \
                              Sel(0)        Sel(n)    Ref(comps)
                                    \             \
                                     Ref(comps)    Ref(comps)

  Both computations are bound once, as a pair named `comps`. The result tuple
  re-selects every element of `comps[0]` under its original name, then
  appends `comps[1]` under `comp2`'s name (or unnamed).

  Args:
    comp1: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_type.NamedTupleType`.
    comp2: A `computation_building_blocks.ComputationBuildingBlock` or a named
      computation (a tuple pair of name, computation) representing a single
      element of an `anonymous_tuple.AnonymousTuple`.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    TypeError: If any of the types do not match.
  """
    py_typecheck.check_type(
        comp1, computation_building_blocks.ComputationBuildingBlock)
    if isinstance(comp2, computation_building_blocks.ComputationBuildingBlock):
        appended_name = None
    elif py_typecheck.is_name_value_pair(
            comp2,
            name_required=False,
            value_type=computation_building_blocks.ComputationBuildingBlock):
        appended_name, comp2 = comp2
    else:
        raise TypeError('Unexpected tuple element: {}.'.format(comp2))
    pair = computation_building_blocks.Tuple((comp1, comp2))
    pair_ref = computation_building_blocks.Reference('comps',
                                                     pair.type_signature)
    first = computation_building_blocks.Selection(pair_ref, index=0)
    elements = [
        (element_name,
         computation_building_blocks.Selection(first, index=position))
        for position, (element_name, _) in enumerate(
            anonymous_tuple.to_elements(comp1.type_signature))
    ]
    elements.append(
        (appended_name, computation_building_blocks.Selection(pair_ref,
                                                              index=1)))
    return computation_building_blocks.Block(
        ((pair_ref.name, pair), ), computation_building_blocks.Tuple(elements))
예제 #25
0
def create_identity_block(variable_name, comp):
    r"""Builds a block that binds `comp` to a name and returns that name.

             Block
            /     \
    [x=comp]       Ref(x)

    Args:
      variable_name: The name to bind `comp` to inside the block.
      comp: The computation bound as the block's single local.

    Returns:
      A `computation_building_blocks.Block` whose only local is
      `(variable_name, comp)` and whose result is a
      `computation_building_blocks.Reference` to `variable_name` typed with
      `comp.type_signature`.
    """
    result_ref = computation_building_blocks.Reference(
        variable_name, comp.type_signature)
    symbols = [(variable_name, comp)]
    return computation_building_blocks.Block(symbols, result_ref)
예제 #26
0
 def _transform(comp):
     """Rewrites a called lambda `(x -> body)(arg)` into `let x=arg in body`.

     Args:
       comp: Any computation building block.

     Returns:
       `comp` itself unless it is a `Call` whose function is a `Lambda`. In
       that case, a `computation_building_blocks.Block` binding the lambda's
       parameter name to the call argument, with the lambda's result as the
       block result.

     Raises:
       Presumably `TypeError` (from `py_typecheck.check_type`) when the
       call's argument is not a `ComputationBuildingBlock`.
     """
     # Short-circuit so `comp.function` is only read once `comp` is a Call.
     is_called_lambda = (
         isinstance(comp, computation_building_blocks.Call) and
         isinstance(comp.function, computation_building_blocks.Lambda))
     if not is_called_lambda:
         return comp
     call_arg = comp.argument
     py_typecheck.check_type(
         call_arg, computation_building_blocks.ComputationBuildingBlock)
     fn = comp.function
     return computation_building_blocks.Block(
         [(fn.parameter_name, call_arg)], fn.result)
예제 #27
0
 def test_no_inlining_if_referenced_twice(self):
     """A local referenced twice must not be inlined by the transformation."""
     data = computation_building_blocks.Data('test_data', tf.int32)
     # Two distinct Reference objects pointing at the same local.
     refs = [
         computation_building_blocks.Reference('test_x', data.type_signature)
         for _ in range(2)
     ]
     block = computation_building_blocks.Block(
         [('test_x', data)], computation_building_blocks.Tuple(refs))
     self.assertEqual(str(block),
                      '(let test_x=test_data in <test_x,test_x>)')
     transformed = transformations.inline_blocks_with_n_referenced_locals(
         block)
     self.assertEqual(str(transformed), str(block))
예제 #28
0
    def _transform_functional_args(comps):
        r"""Merges a list of functional computations into one function.

        Builds the following computation from `comps`:

                        Block
                       /     \
             [fn=Tuple]       Lambda(arg)
                 |                       \
        [Comp(f1), Comp(f2), ...]         Tuple
                                          |
                                     [Call,                  Call, ...]
                                     /    \                 /    \
                               Sel(0)      Sel(0)     Sel(1)      Sel(1)
                              /           /          /           /
                       Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

        There is one `computation_building_blocks.Call` per element of
        `comps`. Element `i` of the lambda's argument is passed to function
        `i`. Two fresh names are drawn from the enclosing `name_generator`:
        first the name for the function tuple, then the name for the
        argument.

        Args:
          comps: A Python list of computations. Each one's `type_signature`
            must expose a `parameter` (i.e. it is function-typed).

        Returns:
          A `computation_building_blocks.Block`.
        """
        functions = computation_building_blocks.Tuple(comps)
        fn_ref = computation_building_blocks.Reference(
            six.next(name_generator), functions.type_signature)
        # The lambda parameter is a tuple of every function's parameter type.
        arg_ref = computation_building_blocks.Reference(
            six.next(name_generator),
            [comp.type_signature.parameter for comp in comps])
        calls = computation_building_blocks.Tuple([
            computation_building_blocks.Call(
                computation_building_blocks.Selection(fn_ref, index=i),
                computation_building_blocks.Selection(arg_ref, index=i))
            for i in range(len(comps))
        ])
        fn = computation_building_blocks.Lambda(arg_ref.name,
                                                arg_ref.type_signature, calls)
        return computation_building_blocks.Block(
            ((fn_ref.name, functions),), fn)
예제 #29
0
 def _extract_from_call(comp):
   """Returns a new computation with all intrinsics extracted.

   The argument of the call `comp` is lifted into a block surrounding the
   call. If that argument is a called intrinsic, it is bound to a fresh name
   from `name_generator`, and the call is made on a reference to that name.
   Otherwise the argument is treated as a block: its locals are reused and
   the call is made on its result. This assumes the argument then has
   `.locals` / `.result`; verify against the caller. The resulting block is
   handed to `_extract_from_block`.
   """
   call_arg = comp.argument
   if not _is_called_intrinsic(call_arg):
     variables, new_arg = call_arg.locals, call_arg.result
   else:
     fresh_name = six.next(name_generator)
     variables = ((fresh_name, call_arg),)
     new_arg = computation_building_blocks.Reference(
         fresh_name, call_arg.type_signature)
   hoisted = computation_building_blocks.Block(
       variables, computation_building_blocks.Call(comp.function, new_arg))
   return _extract_from_block(hoisted)
예제 #30
0
    def test_no_reduce_separated_lambda_and_call(self):
        """A lambda wrapped in a block and then called must not be reduced."""

        # `foo` must keep this name: the expected string uses `foo_arg`.
        @computations.federated_computation(tf.int32)
        def foo(x):
            return x

        wrapped = computation_building_blocks.Block([], _to_building_block(foo))
        call = computation_building_blocks.Call(
            wrapped, computation_building_blocks.Data('test', tf.int32))
        reduced = transformations.replace_called_lambdas_with_block(call)
        self.assertEqual(str(call), '(let  in (foo_arg -> foo_arg))(test)')
        self.assertEqual(str(call), str(reduced))