def test_construct_setattr_named_tuple_type_replaces_single_element(self):
    """Setting `a` on <a=int32,b=bool> rebinds `a` and forwards `b` as-is."""
    tuple_type = computation_types.NamedTupleType([('a', tf.int32),
                                                   ('b', tf.bool)])
    replacement = computation_building_blocks.Data('x', tf.int32)
    setattr_lambda = (
        computation_constructing_utils.construct_named_tuple_setattr_lambda(
            tuple_type, 'a', replacement))
    expected_repr = (
        '(let value_comp_placeholder=x in '
        '(lambda_arg -> <a=value_comp_placeholder,b=lambda_arg[1]>))')
    self.assertEqual(setattr_lambda.tff_repr, expected_repr)
 def test_returns_federated_map(self):
     """`create_federated_map` maps an identity lambda over a CLIENTS value."""
     identity_ref = computation_building_blocks.Reference('x', tf.int32)
     identity_fn = computation_building_blocks.Lambda(
         identity_ref.name, identity_ref.type_signature, identity_ref)
     clients_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS, False)
     clients_value = computation_building_blocks.Data('y', clients_type)
     mapped = computation_constructing_utils.create_federated_map(
         identity_fn, clients_value)
     self.assertEqual(mapped.tff_repr, 'federated_map(<(x -> x),y>)')
     self.assertEqual(str(mapped.type_signature), '{int32}@CLIENTS')
示例#3
0
 def test_simple_block_inlining(self):
     """A local referenced once by the result is inlined, leaving an empty let."""
     data = computation_building_blocks.Data('test_data', tf.int32)
     local_ref = computation_building_blocks.Reference('test_x',
                                                       data.type_signature)
     block = computation_building_blocks.Block([('test_x', data)], local_ref)
     self.assertEqual(str(block), '(let test_x=test_data in test_x)')
     inlined_block = transformations.inline_blocks_with_n_referenced_locals(
         block)
     self.assertEqual(str(inlined_block), '(let  in test_data)')
    def test_replace_called_lambda_replaces_called_lambda(self):
        """An identity lambda called directly on data becomes a `let` block."""
        identity = _create_lambda_to_identity(tf.int32)
        data = computation_building_blocks.Data('x', tf.int32)
        comp = computation_building_blocks.Call(identity, data)

        result = transformations.replace_called_lambda_with_block(comp)

        self.assertEqual(comp.tff_repr, '(arg -> arg)(x)')
        self.assertEqual(result.tff_repr, '(let arg=x in arg)')
示例#5
0
 def test_construct_setattr_named_tuple_type_leaves_type_signature_unchanged(
         self):
     """The setattr lambda's parameter and result types are equivalent."""
     tuple_type = computation_types.NamedTupleType([('a', tf.int32),
                                                    (None, tf.float32),
                                                    ('b', tf.bool)])
     new_value = computation_building_blocks.Data('x', tf.int32)
     setattr_fn = (
         computation_constructing_utils.construct_named_tuple_setattr_lambda(
             tuple_type, 'a', new_value))
     fn_type = setattr_fn.type_signature
     self.assertTrue(
         type_utils.are_equivalent_types(fn_type.parameter, fn_type.result))
示例#6
0
 def test_returns_string_for_data(self):
     """Compact, formatted and structural strings of a `Data` are its URI."""
     comp = computation_building_blocks.Data('data', tf.int32)
     # Checked in the same order as before: compact, formatted, structural.
     representations = (
         computation_building_blocks.compact_representation,
         computation_building_blocks.formatted_representation,
         computation_building_blocks.structural_representation,
     )
     for represent in representations:
         self.assertEqual(represent(comp), 'data')
示例#7
0
 def test_raises_type_error_with_none_zero(self):
     """`create_federated_aggregate` raises TypeError when `zero` is None."""
     clients_type = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS, False)
     value = computation_building_blocks.Data('v', clients_type)
     accumulate = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('a', tf.int32))
     merge = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('m', tf.int32))
     report_param = computation_building_blocks.Reference('r', tf.int32)
     report = computation_building_blocks.Lambda(
         report_param.name, report_param.type_signature, report_param)
     with self.assertRaises(TypeError):
         computation_constructing_utils.create_federated_aggregate(
             value, None, accumulate, merge, report)
示例#8
0
  def test_federated_setattr_call_fails_on_none_value(self):
    """`construct_federated_setattr_call` raises TypeError for a None value."""
    member_type = computation_types.NamedTupleType([('a', tf.int32),
                                                    (None, tf.float32),
                                                    ('b', tf.bool)])
    federated_type = computation_types.FederatedType(
        member_type, placement_literals.CLIENTS)
    federated_data = computation_building_blocks.Data('data', federated_type)

    with self.assertRaises(TypeError):
      computation_constructing_utils.construct_federated_setattr_call(
          federated_data, 'a', None)
    def test_remove_mapped_or_applied_identity_removes_identity(
            self, uri, type_spec, comp_factory):
        """An identity applied via `comp_factory` collapses to its argument."""
        identity = _create_lambda_to_identity(tf.int32)
        data = computation_building_blocks.Data('x', type_spec)
        comp = comp_factory(identity, data)

        result = transformations.remove_mapped_or_applied_identity(comp)

        expected_before = '{}(<(arg -> arg),x>)'.format(uri)
        self.assertEqual(comp.tff_repr, expected_before)
        self.assertEqual(result.tff_repr, 'x')
 def test_returns_string_for_tuple_with_no_names(self):
     """Compact, formatted and structural strings for an unnamed 2-tuple."""
     data = computation_building_blocks.Data('data', tf.int32)
     comp = computation_building_blocks.Tuple((data, data))
     compact_string = comp.compact_representation()
     self.assertEqual(compact_string, '<data,data>')
     formatted_string = comp.formatted_representation()
     # pyformat: disable
     self.assertEqual(formatted_string, '<\n' '  data,\n' '  data\n' '>')
     # pyformat: enable
     structural_string = comp.structural_representation()
     # pyformat: disable
     self.assertEqual(structural_string, 'Tuple\n' '|\n' '[data, data]')
     # pyformat: enable  (re-enable so formatting is not left off afterwards)
    def test_remove_mapped_or_applied_identity_does_not_remove_called_lambda(
            self):
        """An identity lambda that is called directly is left untouched."""
        identity = _create_lambda_to_identity(tf.int32)
        data = computation_building_blocks.Data('x', tf.int32)
        comp = computation_building_blocks.Call(identity, data)

        result = transformations.remove_mapped_or_applied_identity(comp)

        expected = '(arg -> arg)(x)'
        self.assertEqual(comp.tff_repr, expected)
        self.assertEqual(result.tff_repr, expected)
 def test_does_not_find_aggregate_dependent_on_broadcast(self):
   """An aggregate consuming a broadcast should pass the check without raising."""
   broadcast = computation_test_utils.create_dummy_called_federated_broadcast()
   member_type = broadcast.type_signature.member
   zero = computation_building_blocks.Data('zero', member_type)
   accumulate = computation_building_blocks.Lambda(
       'accumulate_parameter', [member_type, member_type],
       computation_building_blocks.Data('accumulate_result', member_type))
   merge = computation_building_blocks.Lambda(
       'merge_parameter', [member_type, member_type],
       computation_building_blocks.Data('merge_result', member_type))
   report = computation_building_blocks.Lambda(
       'report_parameter', member_type,
       computation_building_blocks.Data('report_result', member_type))
   aggregate = computation_constructing_utils.create_federated_aggregate(
       broadcast, zero, accumulate, merge, report)
   tree_analysis.check_broadcast_not_dependent_on_aggregate(aggregate)
示例#13
0
 def test_no_inlining_if_referenced_twice(self):
     """A local referenced twice in the block result is not inlined."""
     data = computation_building_blocks.Data('test_data', tf.int32)
     refs = [
         computation_building_blocks.Reference('test_x', data.type_signature)
         for _ in range(2)
     ]
     block = computation_building_blocks.Block(
         [('test_x', data)], computation_building_blocks.Tuple(refs))
     self.assertEqual(str(block), '(let test_x=test_data in <test_x,test_x>)')
     inlined_block = transformations.inline_blocks_with_n_referenced_locals(
         block)
     self.assertEqual(str(inlined_block), str(block))
    def test_replace_called_lambda_does_not_replace_separated_called_lambda(
            self):
        """A lambda wrapped in a block and then called is left untouched."""
        identity = _create_lambda_to_identity(tf.int32)
        wrapped = _create_dummy_block(identity)
        data = computation_building_blocks.Data('x', tf.int32)
        comp = computation_building_blocks.Call(wrapped, data)

        result = transformations.replace_called_lambda_with_block(comp)

        self.assertEqual(result.tff_repr, comp.tff_repr)
        self.assertEqual(result.tff_repr,
                         '(let local=data in (arg -> arg))(x)')
示例#15
0
def create_dummy_block(comp, variable_name, variable_type=tf.int32):
  r"""Returns a block that binds a dummy `Data` local around `comp`.

           Block
          /     \
  [x=data]       Comp

  The single local, named `variable_name`, is bound to a `Data` with URI
  'data' and type `variable_type`. The block's result is `comp` itself. It is
  NOT an identity block, because the result need not reference the local.

  Args:
    comp: The computation to use as the result.
    variable_name: The name of the variable.
    variable_type: The type of the variable.

  Returns:
    A `computation_building_blocks.Block` with one local and result `comp`.
  """
  data = computation_building_blocks.Data('data', variable_type)
  return computation_building_blocks.Block([(variable_name, data)], comp)
示例#16
0
def create_identity_block_with_dummy_data(variable_name,
                                          variable_type=tf.int32):
  r"""Builds an identity block whose only local is a dummy `Data` value.

           Block
          /     \
  [x=data]       Ref(x)

  Args:
    variable_name: The name of the variable.
    variable_type: The type of the variable.

  Returns:
    Whatever `create_identity_block` returns for `variable_name` bound to a
    `Data` with URI 'data' and type `variable_type`.
  """
  return create_identity_block(
      variable_name, computation_building_blocks.Data('data', variable_type))
def create_dummy_called_federated_broadcast(value_type=tf.int32):
  r"""Returns a dummy called federated broadcast.

                      Call
                     /    \
  federated_broadcast      data

  Args:
    value_type: The member type of the `SERVER`-placed value to broadcast.

  Returns:
    The result of `create_federated_broadcast` applied to a `Data` with URI
    'data' and federated type `value_type@SERVER`.
  """
  federated_type = computation_types.FederatedType(value_type,
                                                   placements.SERVER)
  value = computation_building_blocks.Data('data', federated_type)
  return computation_constructing_utils.create_federated_broadcast(value)
示例#18
0
    def test_no_reduce_separated_lambda_and_call(self):
        """A lambda wrapped in an empty block and then called is not reduced."""

        @computations.federated_computation(tf.int32)
        def foo(x):
            return x

        lambda_comp = _to_building_block(foo)
        empty_block = computation_building_blocks.Block([], lambda_comp)
        data = computation_building_blocks.Data('test', tf.int32)
        called_block = computation_building_blocks.Call(empty_block, data)
        reduced = transformations.replace_called_lambdas_with_block(
            called_block)
        self.assertEqual(str(called_block),
                         '(let  in (foo_arg -> foo_arg))(test)')
        self.assertEqual(str(called_block), str(reduced))
 def test_basic_functionality_of_data_class(self):
     """`Data` exposes its URI and type and round-trips through its proto."""
     data = computation_building_blocks.Data(
         '/tmp/mydata', computation_types.SequenceType(tf.int32))
     self.assertEqual(str(data.type_signature), 'int32*')
     self.assertEqual(data.uri, '/tmp/mydata')
     self.assertEqual(
         repr(data),
         'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
     self.assertEqual(data.tff_repr, '/tmp/mydata')
     proto = data.proto
     self.assertEqual(
         type_serialization.deserialize_type(proto.type), data.type_signature)
     self.assertEqual(proto.WhichOneof('computation'), 'data')
     self.assertEqual(proto.data.uri, data.uri)
     self._serialize_deserialize_roundtrip_test(data)
    def test_replace_chained_federated_maps_does_not_replace_one_federated_map(
            self):
        """A lone federated_map is left unchanged by the chain-merging pass."""
        identity = _create_lambda_to_identity(tf.int32)
        clients_type = computation_types.FederatedType(tf.int32,
                                                       placements.CLIENTS)
        data = computation_building_blocks.Data('x', clients_type)
        comp = _create_called_federated_map(identity, data)

        result = (transformations
                  .replace_chained_federated_maps_with_federated_map(comp))

        self.assertEqual(result.tff_repr, comp.tff_repr)
        self.assertEqual(result.tff_repr, 'federated_map(<(arg -> arg),x>)')
示例#21
0
 def test_returns_federated_aggregate(self):
     """`create_federated_aggregate` builds the call and a SERVER result type."""
     clients_type = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS, False)
     value = computation_building_blocks.Data('v', clients_type)
     zero = computation_building_blocks.Data('z', tf.int32)
     accumulate = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('a', tf.int32))
     merge = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('m', tf.int32))
     report_param = computation_building_blocks.Reference('r', tf.int32)
     report = computation_building_blocks.Lambda(
         report_param.name, report_param.type_signature, report_param)
     aggregate = computation_constructing_utils.create_federated_aggregate(
         value, zero, accumulate, merge, report)
     self.assertEqual(aggregate.tff_repr,
                      'federated_aggregate(<v,z,(x -> a),(x -> m),(r -> r)>)')
     self.assertEqual(str(aggregate.type_signature), 'int32@SERVER')
def create_dummy_called_sequence_map(parameter_name, parameter_type=tf.int32):
  r"""Builds a `sequence_map` call applying an identity lambda to dummy data.

               Call
              /    \
  sequence_map      data

  Args:
    parameter_name: The name of the parameter.
    parameter_type: The type of the parameter.

  Returns:
    The result of `create_sequence_map` over an identity function and a
    `Data` with URI 'data' of type `SequenceType(parameter_type)`.
  """
  identity_fn = create_identity_function(parameter_name, parameter_type)
  sequence = computation_building_blocks.Data(
      'data', computation_types.SequenceType(parameter_type))
  return computation_constructing_utils.create_sequence_map(identity_fn,
                                                            sequence)
示例#23
0
 def test_remove_mapped_or_applied_identity_removes_identity(
         self, uri, data_type):
     """An intrinsic `uri` applying an identity lambda to `x` reduces to `x`."""
     data = computation_building_blocks.Data('x', data_type)
     param = computation_building_blocks.Reference('arg', tf.float32)
     identity = computation_building_blocks.Lambda('arg', tf.float32, param)
     args = computation_building_blocks.Tuple([identity, data])
     fn_sig, data_sig = args.type_signature[0], args.type_signature[1]
     intrinsic = computation_building_blocks.Intrinsic(
         uri, computation_types.FunctionType([fn_sig, data_sig], data_sig))
     call = computation_building_blocks.Call(intrinsic, args)
     self.assertEqual(str(call), '{}(<(arg -> arg),x>)'.format(uri))
     self.assertEqual(
         str(transformations.remove_mapped_or_applied_identity(call)), 'x')
    def test_remove_mapped_or_applied_identity_removes_multiple_identities(
            self):
        """Two chained identity federated_maps both collapse, leaving `x`."""
        identity = _create_lambda_to_identity(tf.int32)
        clients_type = computation_types.FederatedType(tf.int32,
                                                       placements.CLIENTS)
        data = computation_building_blocks.Data('x', clients_type)
        comp = _create_chained_called_federated_map(identity, data, 2)

        result = transformations.remove_mapped_or_applied_identity(comp)

        self.assertEqual(
            comp.tff_repr,
            'federated_map(<(arg -> arg),federated_map(<(arg -> arg),x>)>)')
        self.assertEqual(result.tff_repr, 'x')
    def test_remove_mapped_or_applied_identity_does_not_remove_other_intrinsic(
            self):
        """An identity passed to the unrelated `dummy` intrinsic is kept."""
        identity = _create_lambda_to_identity(tf.int32)
        data = computation_building_blocks.Data('x', tf.int32)
        dummy_type = computation_types.FunctionType(
            [identity.type_signature, data.type_signature],
            data.type_signature)
        dummy = computation_building_blocks.Intrinsic('dummy', dummy_type)
        comp = computation_building_blocks.Call(
            dummy, computation_building_blocks.Tuple((identity, data)))

        result = transformations.remove_mapped_or_applied_identity(comp)

        expected = 'dummy(<(arg -> arg),x>)'
        self.assertEqual(comp.tff_repr, expected)
        self.assertEqual(result.tff_repr, expected)
示例#26
0
 def test_is_anon_tuple_with_py_container(self):
   """Only an `AnonymousTuple` paired with a py-container type is accepted."""
   # Expected True: AnonymousTuple + NamedTupleTypeWithPyContainerType(dict).
   self.assertTrue(
       type_utils.is_anon_tuple_with_py_container(
           anonymous_tuple.AnonymousTuple([('a', 0.0)]),
           computation_types.NamedTupleTypeWithPyContainerType(
               [('a', tf.float32)], dict)))
   # Expected False: the value is a ValueImpl, not an AnonymousTuple.
   non_tuple = value_impl.ValueImpl(
       computation_building_blocks.Data('nothing', tf.int32),
       context_stack_impl.context_stack)
   self.assertFalse(
       type_utils.is_anon_tuple_with_py_container(
           non_tuple,
           computation_types.NamedTupleTypeWithPyContainerType(
               [('a', tf.float32)], dict)))
   # Expected False: the type is a plain NamedTupleType.
   self.assertFalse(
       type_utils.is_anon_tuple_with_py_container(
           anonymous_tuple.AnonymousTuple([('a', 0.0)]),
           computation_types.NamedTupleType([('a', tf.float32)])))
    def test_replace_intrinsic_replaces_multiple_intrinsics(self):
        """Every `dummy` intrinsic in a two-call chain is replaced."""
        fn = _create_lambda_to_dummy_intrinsic(tf.int32)
        data = computation_building_blocks.Data('x', tf.int32)
        comp = _create_chained_call(fn, data, 2)

        transformed = transformations.replace_intrinsic_with_callable(
            comp, 'dummy', lambda x: x, context_stack_impl.context_stack)

        self.assertEqual(comp.tff_repr,
                         '(arg -> dummy(arg))((arg -> dummy(arg))(x))')
        expected = ('(arg -> (dummy_arg -> dummy_arg)(arg))'
                    '((arg -> (dummy_arg -> dummy_arg)(arg))(x))')
        self.assertEqual(transformed.tff_repr, expected)
 def test_returns_string_for_call_with_arg(self):
     """Compact, formatted and structural strings for `(a -> a)(data)`."""
     ref = computation_building_blocks.Reference('a', tf.int32)
     fn = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                             ref)
     arg = computation_building_blocks.Data('data', tf.int32)
     comp = computation_building_blocks.Call(fn, arg)
     compact_string = comp.compact_representation()
     self.assertEqual(compact_string, '(a -> a)(data)')
     formatted_string = comp.formatted_representation()
     self.assertEqual(formatted_string, '(a -> a)(data)')
     structural_string = comp.structural_representation()
     # pyformat: disable
     self.assertEqual(
         structural_string,
         '          Call\n'
         '         /    \\\n'
         'Lambda(a)      data\n'
         '|\n'
         'Ref(a)'
     )
     # pyformat: enable  (re-enable so formatting is not left off afterwards)
def create_dummy_called_federated_map(parameter_name, parameter_type=tf.int32):
  r"""Builds a `federated_map` call of an identity lambda over CLIENTS data.

                Call
               /    \
  federated_map      Tuple
                     |
                     [Lambda(x), data]
                      |
                      Ref(x)

  Args:
    parameter_name: The name of the parameter.
    parameter_type: The type of the parameter.

  Returns:
    The result of `create_federated_map` over an identity function and a
    `Data` with URI 'data' of federated type `parameter_type@CLIENTS`.
  """
  identity_fn = create_identity_function(parameter_name, parameter_type)
  clients_data = computation_building_blocks.Data(
      'data',
      computation_types.FederatedType(parameter_type, placements.CLIENTS))
  return computation_constructing_utils.create_federated_map(identity_fn,
                                                             clients_data)
    def test_replace_chained_federated_maps_does_not_replace_separated_federated_maps(
            self):
        """Two federated_maps separated by a block are not merged."""
        inner_fn = _create_lambda_to_identity(tf.int32)
        clients_type = computation_types.FederatedType(tf.int32,
                                                       placements.CLIENTS)
        data = computation_building_blocks.Data('x', clients_type)
        inner_map = _create_called_federated_map(inner_fn, data)
        separator = _create_dummy_block(inner_map)
        outer_fn = _create_lambda_to_identity(tf.int32)
        comp = _create_called_federated_map(outer_fn, separator)

        result = (transformations
                  .replace_chained_federated_maps_with_federated_map(comp))

        self.assertEqual(result.tff_repr, comp.tff_repr)
        self.assertEqual(
            result.tff_repr,
            'federated_map(<(arg -> arg),'
            '(let local=data in federated_map(<(arg -> arg),x>))>)')