示例#1
0
 def test_raises_type_error_with_none_accumulate(self):
     """`create_federated_aggregate` rejects `None` as the accumulate fn."""
     client_ints = computation_types.FederatedType(
         tf.int32, placements.CLIENTS, False)
     value = computation_building_blocks.Data('v', client_ints)
     zero = computation_building_blocks.Data('z', tf.int32)
     pair_of_ints = computation_types.NamedTupleType((tf.int32, tf.int32))
     merge = computation_building_blocks.Lambda(
         'x', pair_of_ints, computation_building_blocks.Data('m', tf.int32))
     ref_r = computation_building_blocks.Reference('r', tf.int32)
     report = computation_building_blocks.Lambda(
         ref_r.name, ref_r.type_signature, ref_r)
     with self.assertRaises(TypeError):
         computation_constructing_utils.create_federated_aggregate(
             value, zero, None, merge, report)
示例#2
0
 def test_two_tuple_zip_with_named_client_all_equal_int_and_bool(self):
     """Zipping a named <int, bool> pair of client values yields one value."""
     int_at_clients = computation_types.FederatedType(
         tf.int32, placements.CLIENTS, True)
     bool_at_clients = computation_types.FederatedType(
         tf.bool, placements.CLIENTS, True)
     pair_type = computation_types.NamedTupleType(
         [('a', int_at_clients), ('b', bool_at_clients)])
     test_ref = computation_building_blocks.Reference('test', pair_type)
     pair_value = value_impl.to_value(test_ref, None, _context_stack)
     zipped = value_utils.zip_two_tuple(pair_value, _context_stack)
     self.assertEqual(
         str(zipped.type_signature), '{<a=int32,b=bool>}@CLIENTS')
示例#3
0
 def _extract_from_selection(comp):
   """Returns a new computation with intrinsics lifted out of a `Selection`.

   If the selection's source is a called intrinsic, the source is bound to a
   freshly generated name (taken from `name_generator`), and the selection is
   re-applied to a reference to that name. Otherwise the source's `locals` are
   lifted and the selection is re-applied to the source's `result` (presumably
   the source is a `Block` in that case -- TODO confirm against callers). The
   resulting `Block` is then passed through `_extract_from_block`.

   Args:
     comp: A `computation_building_blocks.Selection`.

   Returns:
     Whatever `_extract_from_block` returns for the rebuilt `Block`.
   """
   source = comp.source
   if not _is_called_intrinsic(source):
     variables = source.locals
     selected_from = source.result
   else:
     fresh_name = six.next(name_generator)
     variables = ((fresh_name, source),)
     selected_from = computation_building_blocks.Reference(
         fresh_name, source.type_signature)
   selection = computation_building_blocks.Selection(
       selected_from, name=comp.name, index=comp.index)
   return _extract_from_block(
       computation_building_blocks.Block(variables, selection))
示例#4
0
def _create_lambda_to_identity(dtype):
    r"""Creates a lambda that returns its argument unchanged.

    Lambda(arg)
          \
           Reference(arg)

    Args:
      dtype: The type of the argument.

    Returns:
      An instance of `computation_building_blocks.Lambda`.
    """
    param = computation_building_blocks.Reference('arg', dtype)
    return computation_building_blocks.Lambda(
        param.name, param.type_signature, param)
示例#5
0
 def test_value_impl_with_lambda(self):
     """A `ValueImpl` wrapping a `Lambda` reports its type and string form."""
     context_stack = context_stack_impl.context_stack
     arg_name = 'arg'
     arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                 ('x', tf.int32)]
     arg_value = value_impl.ValueImpl(
         computation_building_blocks.Reference(arg_name, arg_type),
         context_stack)
     # Build the body `arg.f(arg.f(arg.x))` through `ValueImpl` operations.
     body_value = arg_value.f(arg_value.f(arg_value.x))
     lambda_value = value_impl.ValueImpl(
         computation_building_blocks.Lambda(
             arg_name, arg_type, value_impl.ValueImpl.get_comp(body_value)),
         context_stack)
     self.assertIsInstance(lambda_value, value_base.Value)
     self.assertEqual(str(lambda_value.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(str(lambda_value), '(arg -> arg.f(arg.f(arg.x)))')
示例#6
0
 def _extract_from_tuple(comp):
   """Returns a new computation with intrinsics lifted out of a `Tuple`.

   Every element accepted by `_is_called_intrinsic_or_block` is bound to a
   freshly generated name (from `name_generator`) and replaced in the tuple by
   a reference to that name. Other elements stay in place, and element names
   are preserved. The resulting `Block` is passed through `_extract_from_block`.

   Args:
     comp: A `computation_building_blocks.Tuple`.

   Returns:
     Whatever `_extract_from_block` returns for the rebuilt `Block`.
   """
   bindings = []
   new_elements = []
   for elem_name, elem in anonymous_tuple.to_elements(comp):
     if not _is_called_intrinsic_or_block(elem):
       new_elements.append((elem_name, elem))
       continue
     bound_name = six.next(name_generator)
     bindings.append((bound_name, elem))
     bound_ref = computation_building_blocks.Reference(bound_name,
                                                       elem.type_signature)
     new_elements.append((elem_name, bound_ref))
   return _extract_from_block(
       computation_building_blocks.Block(
           bindings, computation_building_blocks.Tuple(new_elements)))
def create_dummy_called_intrinsic(parameter_name, parameter_type=tf.int32):
  r"""Returns a dummy called intrinsic.

            Call
           /    \
  intrinsic      Ref(x)

  The intrinsic is named 'intrinsic' and has the function type
  `(parameter_type -> parameter_type)`. It is called on a reference named
  `parameter_name` of type `parameter_type`.

  Args:
    parameter_name: The name of the parameter.
    parameter_type: The type of the parameter.

  Returns:
    A `computation_building_blocks.Call`.
  """
  fn = computation_building_blocks.Intrinsic(
      'intrinsic',
      computation_types.FunctionType(parameter_type, parameter_type))
  return computation_building_blocks.Call(
      fn, computation_building_blocks.Reference(parameter_name, parameter_type))
示例#8
0
 def test_flatten_function(self, n):
     """`flatten_first_index` adds one int32 to an n-tuple parameter/result."""
     ref = computation_building_blocks.Reference('test', [tf.int32] * n)
     identity_fn = computation_building_blocks.Lambda(
         'test', ref.type_signature, ref)
     type_to_add = (None, computation_types.to_type(tf.int32))
     expected_type = computation_types.FunctionType(
         computation_types.NamedTupleType([ref.type_signature, type_to_add]),
         computation_types.to_type([tf.int32] * (n + 1)))
     new_func = value_utils.flatten_first_index(
         value_impl.to_value(identity_fn, None, _context_stack), type_to_add,
         _context_stack)
     self.assertEqual(str(new_func.type_signature), str(expected_type))
示例#9
0
    def test_propogates_dependence_into_binding_to_reference(self):
        """A reference bound to a GENERIC_ZERO intrinsic is a consumer of it."""
        zero_uri = intrinsic_defs.GENERIC_ZERO.uri
        fed_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)
        ref_to_x = computation_building_blocks.Reference('x', fed_type)
        federated_zero = computation_building_blocks.Intrinsic(
            zero_uri, fed_type)

        def is_federated_zero(node):
            if not isinstance(node, computation_building_blocks.Intrinsic):
                return False
            return node.uri == zero_uri

        block = computation_building_blocks.Block(
            [('x', federated_zero)], ref_to_x)
        consumers = tree_analysis.extract_nodes_consuming(
            block, is_federated_zero)
        self.assertIn(ref_to_x, consumers)
 def test_basic_functionality_of_selection_class(self):
     """Exercises `Selection` by name and by index on a two-element tuple.

     Checks the type signatures, `repr` and compact representations of the
     selections. Also checks that an unknown name and out-of-range indices
     (including a negative one) raise `ValueError`, inspects the serialized
     proto of the name-based selection, and round-trips every valid selection.
     """
     x = computation_building_blocks.Reference('foo', [('bar', tf.int32),
                                                       ('baz', tf.bool)])
     # Selection by name.
     y = computation_building_blocks.Selection(x, name='bar')
     self.assertEqual(y.name, 'bar')
     self.assertEqual(y.index, None)
     self.assertEqual(str(y.type_signature), 'int32')
     self.assertEqual(
         repr(y), 'Selection(Reference(\'foo\', NamedTupleType(['
         '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
         ', name=\'bar\')')
     self.assertEqual(computation_building_blocks.compact_representation(y),
                      'foo.bar')
     z = computation_building_blocks.Selection(x, name='baz')
     self.assertEqual(str(z.type_signature), 'bool')
     self.assertEqual(computation_building_blocks.compact_representation(z),
                      'foo.baz')
     # A name that is not in the tuple is rejected.
     with self.assertRaises(ValueError):
         _ = computation_building_blocks.Selection(x, name='bak')
     # Selection by index.
     x0 = computation_building_blocks.Selection(x, index=0)
     self.assertEqual(x0.name, None)
     self.assertEqual(x0.index, 0)
     self.assertEqual(str(x0.type_signature), 'int32')
     self.assertEqual(
         repr(x0), 'Selection(Reference(\'foo\', NamedTupleType(['
         '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
         ', index=0)')
     self.assertEqual(
         computation_building_blocks.compact_representation(x0), 'foo[0]')
     x1 = computation_building_blocks.Selection(x, index=1)
     self.assertEqual(str(x1.type_signature), 'bool')
     self.assertEqual(
         computation_building_blocks.compact_representation(x1), 'foo[1]')
     # Indices past the end, and negative indices, are rejected.
     with self.assertRaises(ValueError):
         _ = computation_building_blocks.Selection(x, index=2)
     with self.assertRaises(ValueError):
         _ = computation_building_blocks.Selection(x, index=-1)
     # The proto of the name-based selection embeds the source proto and name.
     y_proto = y.proto
     self.assertEqual(type_serialization.deserialize_type(y_proto.type),
                      y.type_signature)
     self.assertEqual(y_proto.WhichOneof('computation'), 'selection')
     self.assertEqual(str(y_proto.selection.source), str(x.proto))
     self.assertEqual(y_proto.selection.name, 'bar')
     self._serialize_deserialize_roundtrip_test(y)
     self._serialize_deserialize_roundtrip_test(z)
     self._serialize_deserialize_roundtrip_test(x0)
     self._serialize_deserialize_roundtrip_test(x1)
 def test_returns_string_for_call_with_arg(self):
     """Checks the compact, formatted and structural forms of `(a -> a)(data)`."""
     ref = computation_building_blocks.Reference('a', tf.int32)
     identity = computation_building_blocks.Lambda(
         ref.name, ref.type_signature, ref)
     data = computation_building_blocks.Data('data', tf.int32)
     comp = computation_building_blocks.Call(identity, data)
     expected_inline = '(a -> a)(data)'
     self.assertEqual(comp.compact_representation(), expected_inline)
     self.assertEqual(comp.formatted_representation(), expected_inline)
     # pyformat: disable
     self.assertEqual(
         comp.structural_representation(),
         '          Call\n'
         '         /    \\\n'
         'Lambda(a)      data\n'
         '|\n'
         'Ref(a)')
     # pyformat: enable
示例#12
0
 def test_raises_type_error_with_none_value(self):
     """`create_federated_aggregate` rejects `None` as the value."""
     zero = computation_building_blocks.Data('z', tf.int32)
     accumulate = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('a', tf.int32))
     merge = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('m', tf.int32))
     ref_r = computation_building_blocks.Reference('r', tf.int32)
     report = computation_building_blocks.Lambda(
         ref_r.name, ref_r.type_signature, ref_r)
     with self.assertRaises(TypeError):
         computation_constructing_utils.create_federated_aggregate(
             None, zero, accumulate, merge, report)
示例#13
0
def construct_federated_getitem_comp(comp, key):
    """Constructs the computation for `federated_apply` of `__getitem__`.

    Builds a `computation_building_blocks.Lambda` that selects `key` from its
    argument. The argument's type is `comp.type_signature.member`, which must be
    a `computation_types.NamedTupleType`.

    Args:
      comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
        with type signature `computation_types.FederatedType` whose `member`
        attribute is of type `computation_types.NamedTupleType`.
      key: Instance of `int` or `slice`, used to select elements from the
        member of `comp`. Negative `int` keys count from the end, as for Python
        sequences. Slices follow `slice.indices` semantics.

    Returns:
      Instance of `computation_building_blocks.Lambda` which selects from its
      argument according to `key`.

    Raises:
      TypeError: If `comp`, its type signature, its member type, or `key` has
        the wrong type.
      ValueError: If an `int` key is out of range (raised by `Selection`).
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(comp.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(comp.type_signature.member,
                            computation_types.NamedTupleType)
    py_typecheck.check_type(key, (int, slice))
    apply_input = computation_building_blocks.Reference(
        'x', comp.type_signature.member)
    elems = anonymous_tuple.to_elements(comp.type_signature.member)
    if isinstance(key, int):
        # `Selection` rejects negative indices, while the slice branch below
        # accepts negative bounds. Translate in-range negative keys to their
        # non-negative equivalent so both branches follow Python semantics.
        # Keys below `-len(elems)` pass through unchanged, so `Selection`
        # still reports them as out of range.
        if -len(elems) <= key < 0:
            key += len(elems)
        selected = computation_building_blocks.Selection(apply_input,
                                                         index=key)
    else:
        # Keep each selected element's original name.
        selected = computation_building_blocks.Tuple([
            (elems[k][0],
             computation_building_blocks.Selection(apply_input, index=k))
            for k in range(*key.indices(len(elems)))
        ])
    return computation_building_blocks.Lambda(
        'x', apply_input.type_signature, selected)
示例#14
0
 def _transform(comp):
     """Composes two nested map-style calls into a single call.

     The input is handled only when `_should_transform(comp)` holds; otherwise
     `comp` is returned unchanged. `comp.argument` is read as
     `<outer_fn, inner_call>`, and `inner_call.argument` as
     `<inner_fn, map_arg>`. The result calls an intrinsic with the same URI and
     result type as `comp.function`, on the argument
     `<(inner_arg -> outer_fn(inner_fn(inner_arg))), map_arg>`.
     """
     if not _should_transform(comp):
         return comp
     outer_fn = comp.argument[0]
     inner_map = comp.argument[1]
     inner_fn = inner_map.argument[0]
     map_arg = inner_map.argument[1]
     # The fused lambda's parameter ranges over members of `map_arg`'s type.
     param = computation_building_blocks.Reference(
         'inner_arg', map_arg.type_signature.member)
     composed = computation_building_blocks.Call(
         outer_fn, computation_building_blocks.Call(inner_fn, param))
     fused_fn = computation_building_blocks.Lambda(
         param.name, param.type_signature, composed)
     fused_arg = computation_building_blocks.Tuple([fused_fn, map_arg])
     intrinsic = computation_building_blocks.Intrinsic(
         comp.function.uri,
         computation_types.FunctionType(fused_arg.type_signature,
                                        comp.function.type_signature.result))
     return computation_building_blocks.Call(intrinsic, fused_arg)
 def test_returns_string_for_block(self):
     """Checks the compact, formatted and structural forms of a `Block`."""
     data = computation_building_blocks.Data('data', tf.int32)
     ref = computation_building_blocks.Reference('c', tf.int32)
     block = computation_building_blocks.Block((('a', data), ('b', data)), ref)
     self.assertEqual(block.compact_representation(),
                      '(let a=data,b=data in c)')
     # pyformat: disable
     self.assertEqual(
         block.formatted_representation(),
         '(let\n'
         '  a=data,\n'
         '  b=data\n'
         ' in c)')
     self.assertEqual(
         block.structural_representation(),
         '                 Block\n'
         '                /     \\\n'
         '[a=data, b=data]       Ref(c)')
     # pyformat: enable
def _create_lambda_to_identity(type_spec):
    r"""Creates a lambda that returns its argument unchanged.

    Lambda(arg)
          \
           Ref(arg)

    Args:
      type_spec: The type of the argument.

    Returns:
      A `computation_building_blocks.Lambda`.

    Raises:
      TypeError: If `type_spec` is not a `tf.dtypes.DType`.
    """
    py_typecheck.check_type(type_spec, tf.dtypes.DType)
    param = computation_building_blocks.Reference('arg', type_spec)
    return computation_building_blocks.Lambda(
        param.name, param.type_signature, param)
    def test_replace_chained_federated_maps_replaces_federated_maps_with_different_types(
            self):
        """A cast map followed by a float32 identity map fuse into one map."""
        cast_fn = _create_lambda_to_dummy_cast(tf.int32, tf.float32)
        arg = computation_building_blocks.Reference(
            'x',
            computation_types.FederatedType(tf.int32, placements.CLIENTS))
        inner_map = _create_called_federated_map(cast_fn, arg)
        outer_map = _create_called_federated_map(
            _create_lambda_to_identity(tf.float32), inner_map)

        transformed_comp = (
            transformations.replace_chained_federated_maps_with_federated_map(
                outer_map))

        self.assertEqual(
            outer_map.tff_repr,
            'federated_map(<(arg -> arg),federated_map(<(arg -> data),x>)>)')
        self.assertEqual(
            transformed_comp.tff_repr,
            'federated_map(<(arg -> (arg -> arg)((arg -> data)(arg))),x>)')
示例#18
0
 def test_remove_mapped_or_applied_identity_removes_nested_identity(
         self, uri, data_type):
     """An identity call nested in a tuple inside a lambda is removed."""
     data = computation_building_blocks.Data('x', data_type)
     arg_ref = computation_building_blocks.Reference('arg', tf.float32)
     identity_fn = computation_building_blocks.Lambda(
         'arg', tf.float32, arg_ref)
     call_arg = computation_building_blocks.Tuple([identity_fn, data])
     fn_type = computation_types.FunctionType(
         [call_arg.type_signature[0], call_arg.type_signature[1]],
         call_arg.type_signature[1])
     intrinsic = computation_building_blocks.Intrinsic(uri, fn_type)
     wrapped = computation_building_blocks.Tuple(
         [computation_building_blocks.Call(intrinsic, call_arg)])
     outer_lambda = computation_building_blocks.Lambda('y', tf.int32, wrapped)
     self.assertEqual(str(outer_lambda),
                      '(y -> <{}(<(arg -> arg),x>)>)'.format(uri))
     reduced = transformations.remove_mapped_or_applied_identity(outer_lambda)
     self.assertEqual(str(reduced), '(y -> <x>)')
示例#19
0
 def _transform(comp, context_tree):
   """Renames References in `comp` to unique names.

   New names are read from `context_tree.get_payload_with_name(...).new_name`.
   For a `Block`, `context_tree.walk_down_one_variable_binding()` is called
   once before each local's lookup, in binding order. For a `Lambda` it is
   called once before the parameter's lookup.

   Args:
     comp: The computation node to rename.
     context_tree: The tree supplying the renaming payloads.

   Returns:
     A `(computation, modified)` pair. `modified` is `False`, and `comp` is
     returned unchanged, only when `comp` is not a `Reference`, `Block`, or
     `Lambda`.
   """
   if isinstance(comp, computation_building_blocks.Reference):
     payload = context_tree.get_payload_with_name(comp.name)
     renamed = computation_building_blocks.Reference(
         payload.new_name, comp.type_signature, comp.context)
     return renamed, True
   if isinstance(comp, computation_building_blocks.Block):
     renamed_locals = []
     for old_name, value in comp.locals:
       # The walk must happen before this local's name is looked up.
       context_tree.walk_down_one_variable_binding()
       payload = context_tree.get_payload_with_name(old_name)
       renamed_locals.append((payload.new_name, value))
     return computation_building_blocks.Block(renamed_locals,
                                              comp.result), True
   if isinstance(comp, computation_building_blocks.Lambda):
     context_tree.walk_down_one_variable_binding()
     payload = context_tree.get_payload_with_name(comp.parameter_name)
     return computation_building_blocks.Lambda(payload.new_name,
                                               comp.parameter_type,
                                               comp.result), True
   return comp, False
    def test_returns_string_for_comp_with_left_overhang(self):
        """String forms of `a(comp#bbbbb())`, whose tree overhangs leftwards."""
        fn = computation_building_blocks.Reference(
            'a', computation_types.FunctionType(tf.int32, tf.int32))
        proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        compiled = computation_building_blocks.CompiledComputation(
            proto, 'bbbbb')
        comp = computation_building_blocks.Call(
            fn, computation_building_blocks.Call(compiled))
        expected_inline = 'a(comp#bbbbb())'
        self.assertEqual(comp.compact_representation(), expected_inline)
        self.assertEqual(comp.formatted_representation(), expected_inline)
        # pyformat: disable
        self.assertEqual(
            comp.structural_representation(),
            '           Call\n'
            '          /    \\\n'
            '    Ref(a)      Call\n'
            '               /\n'
            'Compiled(bbbbb)')
        # pyformat: enable
示例#21
0
 def _extract_from_block(comp):
   """Returns a new `Block` with extractable nodes lifted into its locals.

   Exactly one of three cases applies:

   * The block's result is a called intrinsic. The call is bound to a freshly
     generated name (from `name_generator`), appended after the existing
     locals, and the result becomes a reference to that name.
   * The block's result is itself a `Block`. The two are merged: the outer
     locals, then the inner locals, with the inner block's result.
   * Otherwise, every local whose value is a `Block` is flattened. Its locals
     are inserted in place, followed by the original name bound to its result.

   `comp` itself is never mutated.

   Args:
     comp: A `computation_building_blocks.Block`.

   Returns:
     A new `computation_building_blocks.Block`.
   """
   if _is_called_intrinsic(comp.result):
     called_intrinsic = comp.result
     name = six.next(name_generator)
     # Copy before appending so the input block's locals are never modified
     # in place, whatever container `comp.locals` hands back.
     variables = list(comp.locals)
     variables.append((name, called_intrinsic))
     result = computation_building_blocks.Reference(
         name, called_intrinsic.type_signature)
     return computation_building_blocks.Block(variables, result)
   elif isinstance(comp.result, computation_building_blocks.Block):
     # `list(...)` on both sides keeps the concatenation valid even when one
     # side holds its locals as a tuple.
     return computation_building_blocks.Block(
         list(comp.locals) + list(comp.result.locals), comp.result.result)
   else:
     variables = []
     for name, variable in comp.locals:
       if isinstance(variable, computation_building_blocks.Block):
         variables.extend(variable.locals)
         variables.append((name, variable.result))
       else:
         variables.append((name, variable))
     return computation_building_blocks.Block(variables, comp.result)
示例#22
0
 def _extract_from_lambda(comp):
     """Returns a new computation with intrinsics lifted out of a `Lambda`.

     If the lambda's result is a called intrinsic, the call is bound to a
     freshly generated name (from `name_generator`):

     * When `_contains_unbound_reference(result, parameter_name)` is false,
       the binding is placed in a `Block` outside the lambda.
     * Otherwise it stays in a `Block` inside the lambda.

     Otherwise the result is treated as a `Block` (presumably guaranteed by
     callers -- TODO confirm). Each of its locals is hoisted outside the
     lambda unless it refers to the parameter or to a local already retained
     inside. The rebuilt outer `Block` is passed through `_extract_from_block`.

     Args:
       comp: A `computation_building_blocks.Lambda`.

     Returns:
       The rebuilt computation.
     """
     param_name = comp.parameter_name
     param_type = comp.parameter_type
     if _is_called_intrinsic(comp.result):
         intrinsic_call = comp.result
         bound_name = six.next(name_generator)
         bindings = ((bound_name, intrinsic_call), )
         bound_ref = computation_building_blocks.Reference(
             bound_name, intrinsic_call.type_signature)
         if _contains_unbound_reference(intrinsic_call, param_name):
             # The call refers to the parameter, so it stays inside the lambda.
             return computation_building_blocks.Lambda(
                 param_name, param_type,
                 computation_building_blocks.Block(bindings, bound_ref))
         return computation_building_blocks.Block(
             bindings,
             computation_building_blocks.Lambda(param_name, param_type,
                                                bound_ref))

     inner_block = comp.result
     hoisted = []
     kept = []
     for local_name, local_value in inner_block.locals:
         kept_names = [n for n, _ in kept]
         # A local that depends on the parameter, or on a local kept inside,
         # cannot be moved out of the lambda.
         if (_contains_unbound_reference(local_value, param_name)
                 or _contains_unbound_reference(local_value, kept_names)):
             kept.append((local_name, local_value))
         else:
             hoisted.append((local_name, local_value))
     if kept:
         new_body = computation_building_blocks.Block(kept, inner_block.result)
     else:
         new_body = inner_block.result
     fn = computation_building_blocks.Lambda(param_name, param_type, new_body)
     return _extract_from_block(computation_building_blocks.Block(hoisted, fn))
示例#23
0
    def test_remove_mapped_or_applied_identity_does_not_remove_other_intrinsic(
            self):
        """An identity passed to the intrinsic 'dummy' is left in place."""
        uri = 'dummy'
        data = computation_building_blocks.Data('x', tf.int32)
        arg_ref = computation_building_blocks.Reference('arg', tf.float32)
        identity_fn = computation_building_blocks.Lambda(
            'arg', tf.float32, arg_ref)
        call_arg = computation_building_blocks.Tuple([identity_fn, data])
        fn_type = computation_types.FunctionType(
            [call_arg.type_signature[0], call_arg.type_signature[1]],
            call_arg.type_signature[1])
        intrinsic = computation_building_blocks.Intrinsic(uri, fn_type)
        comp = computation_building_blocks.Call(intrinsic, call_arg)

        transformed_comp = transformations.remove_mapped_or_applied_identity(
            comp)

        expected = '{}(<(arg -> arg),x>)'.format(uri)
        self.assertEqual(str(comp), expected)
        self.assertEqual(str(transformed_comp), expected)
def _construct_naming_function(tuple_type_to_name, names_to_add):
    """Private function to construct a lambda naming a given tuple type.

    Args:
      tuple_type_to_name: Instance of `computation_types.NamedTupleType`, the
        type of the argument which we wish to name.
      names_to_add: Python `list` or `tuple`, the names we wish to give to
        the elements of `tuple_type_to_name`.

    Returns:
      An instance of `computation_building_blocks.Lambda` representing a
      function which takes an argument of type `tuple_type_to_name` and returns
      a tuple with the same elements, with the names in `names_to_add`
      attached.

    Raises:
      TypeError: If `tuple_type_to_name` is not a
        `computation_types.NamedTupleType`.
      ValueError: If `tuple_type_to_name` and `names_to_add` have different
        lengths.
    """
    py_typecheck.check_type(tuple_type_to_name,
                            computation_types.NamedTupleType)
    if len(names_to_add) != len(tuple_type_to_name):
        raise ValueError(
            'Number of elements in `names_to_add` must match number of elements '
            'in the named tuple type `tuple_type_to_name`; here, `names_to_add` '
            'has {} elements and `tuple_type_to_name` has {}.'.format(
                len(names_to_add), len(tuple_type_to_name)))
    naming_lambda_arg = computation_building_blocks.Reference(
        'x', tuple_type_to_name)
    # Element i of the result is element i of the argument, renamed to
    # names_to_add[i].
    named_result = computation_building_blocks.Tuple([
        (new_name,
         computation_building_blocks.Selection(naming_lambda_arg, index=i))
        for i, new_name in enumerate(names_to_add)
    ])
    return computation_building_blocks.Lambda(
        'x', naming_lambda_arg.type_signature, named_result)
示例#25
0
    def test_replace_chained_federated_maps_does_not_replace_one_federated_maps(
            self):
        """A single federated map has nothing to fuse and keeps its behavior."""
        map_arg = computation_building_blocks.Reference(
            'arg',
            computation_types.FederatedType(tf.int32, placements.CLIENTS))
        add_one = _create_lambda_to_add_one(map_arg.type_signature.member)
        comp = computation_building_blocks.Lambda(
            map_arg.name, map_arg.type_signature,
            _create_call_to_federated_map(add_one, map_arg))
        uri = intrinsic_defs.FEDERATED_MAP.uri

        self.assertEqual(_get_number_of_intrinsics(comp, uri), 1)
        self.assertEqual(_to_comp(comp)([1]), [2])

        transformed_comp = (
            transformations.replace_chained_federated_maps_with_federated_map(
                comp))

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 1)
        self.assertEqual(_to_comp(transformed_comp)([1]), [2])
 def test_basic_functionality_of_lambda_class(self):
     """Exercises `Lambda` on the body `arg.f(arg.f(arg.x))`.

     Checks the type signature, parameter name and type, the compact form of
     the result, the `repr`, and the compact form of the whole lambda. Also
     inspects the serialized proto and round-trips the lambda.
     """
     arg_name = 'arg'
     arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                 ('x', tf.int32)]
     arg = computation_building_blocks.Reference(arg_name, arg_type)
     arg_f = computation_building_blocks.Selection(arg, name='f')
     arg_x = computation_building_blocks.Selection(arg, name='x')
     # Body applies `f` twice to `x`.
     x = computation_building_blocks.Lambda(
         arg_name, arg_type,
         computation_building_blocks.Call(
             arg_f, computation_building_blocks.Call(arg_f, arg_x)))
     self.assertEqual(str(x.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(x.parameter_name, arg_name)
     self.assertEqual(str(x.parameter_type), '<f=(int32 -> int32),x=int32>')
     self.assertEqual(
         computation_building_blocks.compact_representation(x.result),
         'arg.f(arg.f(arg.x))')
     # `repr` of the parameter type, reused at every occurrence below.
     arg_type_repr = (
         'NamedTupleType(['
         '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
         '(\'x\', TensorType(tf.int32))])')
     self.assertEqual(
         repr(x), 'Lambda(\'arg\', {0}, '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
             arg_type_repr))
     self.assertEqual(computation_building_blocks.compact_representation(x),
                      '(arg -> arg.f(arg.f(arg.x)))')
     # `lambda` is a Python keyword, hence `getattr` for the proto field.
     x_proto = x.proto
     self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                      x.type_signature)
     self.assertEqual(x_proto.WhichOneof('computation'), 'lambda')
     self.assertEqual(getattr(x_proto, 'lambda').parameter_name, arg_name)
     self.assertEqual(str(getattr(x_proto, 'lambda').result),
                      str(x.result.proto))
     self._serialize_deserialize_roundtrip_test(x)
示例#27
0
def construct_federated_getattr_comp(comp, name):
    """Builds the computation for `federated_apply` of `__getattr__`.

    The returned lambda takes an argument of type `comp.type_signature.member`
    (a `computation_types.NamedTupleType`) and selects its element `name`.

    Args:
      comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
        with type signature `computation_types.FederatedType` whose `member`
        attribute is of type `computation_types.NamedTupleType`.
      name: String name of the attribute to select.

    Returns:
      Instance of `computation_building_blocks.Lambda` which selects the
      attribute `name` of its argument.

    Raises:
      TypeError: If `comp`, its type signature, its member type, or `name`
        has the wrong type.
      ValueError: If the member type has no element named `name`.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    federated_type = comp.type_signature
    py_typecheck.check_type(federated_type, computation_types.FederatedType)
    member_type = federated_type.member
    py_typecheck.check_type(member_type, computation_types.NamedTupleType)
    py_typecheck.check_type(name, six.string_types)
    has_element = any(
        elem_name == name
        for elem_name, _ in anonymous_tuple.to_elements(member_type))
    if not has_element:
        raise ValueError(
            'The federated value {} has no element of name {}'.format(
                comp, name))
    param = computation_building_blocks.Reference('x', member_type)
    return computation_building_blocks.Lambda(
        'x', param.type_signature,
        computation_building_blocks.Selection(param, name=name))
示例#28
0
 def test_returns_string_for_comp_with_right_overhang(self):
     """String forms of a call whose tree overhangs to the right.

     The lambda body selects index 0 of a five-element tuple. Its structural
     drawing therefore extends below and to the right of the `Call` node.
     """
     ref = computation_building_blocks.Reference('a', tf.int32)
     data = computation_building_blocks.Data('data', tf.int32)
     tup = computation_building_blocks.Tuple([ref, data, data, data, data])
     sel = computation_building_blocks.Selection(tup, index=0)
     fn = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                             sel)
     comp = computation_building_blocks.Call(fn, data)
     compact_string = computation_building_blocks.compact_representation(
         comp)
     self.assertEqual(compact_string,
                      '(a -> <a,data,data,data,data>[0])(data)')
     formatted_string = computation_building_blocks.formatted_representation(
         comp)
     # pyformat: disable
     self.assertEqual(
         formatted_string, '(a -> <\n'
         '  a,\n'
         '  data,\n'
         '  data,\n'
         '  data,\n'
         '  data\n'
         '>[0])(data)')
     # pyformat: enable
     structural_string = computation_building_blocks.structural_representation(
         comp)
     # pyformat: disable
     self.assertEqual(
         structural_string, '          Call\n'
         '         /    \\\n'
         'Lambda(a)      data\n'
         '|\n'
         'Sel(0)\n'
         '|\n'
         'Tuple\n'
         '|\n'
         '[Ref(a), data, data, data, data]')
示例#29
0
def _create_lambda_to_cast(dtype1, dtype2):
    r"""Creates a computation to TensorFlow cast from dtype1 to dtype2.

    Lambda
          \
           Call
          /    \
    Compiled   Reference
    Computation

    Where `CompiledComputation` is a TensorFlow computation casting
    from `dtype1` to `dtype2`.

    The `dtype` arguments can be either instances of `tf.dtypes.DType` or
    `computation_types.TensorType`. In the latter case the `tf.dtypes.DType`
    of the tensor type is extracted.

    Args:
      dtype1: The type of the argument.
      dtype2: The type to cast the argument to.

    Returns:
      An instance of `computation_building_blocks.Lambda` wrapping a function
      that casts TensorFlow dtype1 to dtype2.

    Raises:
      TypeError: If either argument is not a `tf.dtypes.DType` after the
        `TensorType` extraction (raised by `py_typecheck.check_type`).
    """
    if isinstance(dtype1, computation_types.TensorType):
        dtype1 = dtype1.dtype
    if isinstance(dtype2, computation_types.TensorType):
        dtype2 = dtype2.dtype
    py_typecheck.check_type(dtype1, tf.dtypes.DType)
    py_typecheck.check_type(dtype2, tf.dtypes.DType)
    arg = computation_building_blocks.Reference('arg', dtype1)
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda x: tf.cast(x, dtype2), dtype1, context_stack_impl.context_stack)
    # `serialize_py_fn_as_tf_computation` may return a `(proto, type)` pair
    # rather than the bare proto; `CompiledComputation` needs only the proto.
    tf_proto = serialized[0] if isinstance(serialized, tuple) else serialized
    compiled_comp = computation_building_blocks.CompiledComputation(tf_proto)
    call = computation_building_blocks.Call(compiled_comp, arg)
    return computation_building_blocks.Lambda(arg.name, dtype1, call)
示例#30
0
 def test_returns_federated_aggregate(self):
     """Builds a `federated_aggregate` call and checks its repr and type."""
     value = computation_building_blocks.Data(
         'v',
         computation_types.FederatedType(tf.int32, placements.CLIENTS, False))
     zero = computation_building_blocks.Data('z', tf.int32)
     accumulate = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('a', tf.int32))
     merge = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('m', tf.int32))
     ref_r = computation_building_blocks.Reference('r', tf.int32)
     report = computation_building_blocks.Lambda(
         ref_r.name, ref_r.type_signature, ref_r)
     aggregate = computation_constructing_utils.create_federated_aggregate(
         value, zero, accumulate, merge, report)
     self.assertEqual(
         aggregate.tff_repr,
         'federated_aggregate(<v,z,(x -> a),(x -> m),(r -> r)>)')
     self.assertEqual(str(aggregate.type_signature), 'int32@SERVER')