Esempio n. 1
0
def create_whimsy_block(comp, variable_name, variable_type=tf.int32):
    r"""Returns a whimsy block that binds one `Data` local and yields `comp`.

    Note that this is not an identity block: the result of the block is
    `comp` itself, not a reference to the bound variable.

                 Block
                /     \
        [x=data]       Comp

    Args:
      comp: The computation to use as the result of the block.
      variable_name: The name bound to the whimsy `Data` local.
      variable_type: The type of the whimsy `Data` local.

    Returns:
      A `building_blocks.Block` with a single local.
    """
    whimsy_data = building_blocks.Data('data', variable_type)
    return building_blocks.Block([(variable_name, whimsy_data)], comp)
Esempio n. 2
0
def create_identity_block_with_dummy_data(variable_name,
                                          variable_type=tf.int32):
    r"""Builds an identity block whose only local is a dummy `Data`.

             Block
            /     \
    [x=data]       Ref(x)

    Args:
      variable_name: The name bound to the dummy `Data` computation.
      variable_type: The type of the dummy `Data` computation.

    Returns:
      Whatever `create_identity_block` returns for these arguments.
    """
    return create_identity_block(variable_name,
                                 building_blocks.Data('data', variable_type))
Esempio n. 3
0
def create_dummy_called_federated_broadcast(value_type=tf.int32):
    r"""Returns a dummy called federated broadcast.

                          Call
                         /    \
      federated_broadcast      data

    Args:
      value_type: The type of the value. It is wrapped in a `FederatedType`
        placed at `SERVER` before being broadcast.

    Returns:
      The `building_blocks.Call` produced by
      `building_block_factory.create_federated_broadcast`.
    """
    federated_type = computation_types.FederatedType(value_type,
                                                     placements.SERVER)
    value = building_blocks.Data('data', federated_type)
    return building_block_factory.create_federated_broadcast(value)
Esempio n. 4
0
def create_whimsy_called_federated_sum(value_type=tf.int32):
    r"""Builds a whimsy call to the `federated_sum` intrinsic.

                  Call
                 /    \
    federated_sum      data

    Args:
      value_type: The member type of the summed value, which is placed at
        `CLIENTS`.

    Returns:
      The `building_blocks.Call` built by
      `building_block_factory.create_federated_sum`.
    """
    clients_value = building_blocks.Data(
        'data',
        computation_types.FederatedType(value_type, placements.CLIENTS))
    return building_block_factory.create_federated_sum(clients_value)
Esempio n. 5
0
def create_dummy_called_federated_aggregate(accumulate_parameter_name,
                                            merge_parameter_name,
                                            report_parameter_name,
                                            value_type=tf.int32):
  r"""Builds a dummy call to the `federated_aggregate` intrinsic.

                      Call
                     /    \
  federated_aggregate      Tuple
                           |
                           [data, data, Lambda(x), Lambda(x), Lambda(x)]
                                        |          |          |
                                        data       data       data

  The accumulator type is `tf.float32`. `accumulate` takes
  `<float32, value_type>`, `merge` takes `<float32, float32>`, and `report`
  takes a `float32` and yields a `tf.bool` `Data`.

  Args:
    accumulate_parameter_name: The parameter name of the accumulate lambda.
    merge_parameter_name: The parameter name of the merge lambda.
    report_parameter_name: The parameter name of the report lambda.
    value_type: The TFF type of the value to be aggregated, placed at
      CLIENTS.

  Returns:
    The `building_blocks.Call` built by
    `building_block_factory.create_federated_aggregate`.
  """
  client_values = building_blocks.Data(
      'data', computation_types.FederatedType(value_type, placements.CLIENTS))
  zero = building_blocks.Data('data', tf.float32)

  accumulate = building_blocks.Lambda(
      accumulate_parameter_name,
      computation_types.NamedTupleType((tf.float32, value_type)),
      building_blocks.Data('data', tf.float32))
  merge = building_blocks.Lambda(
      merge_parameter_name,
      computation_types.NamedTupleType((tf.float32, tf.float32)),
      building_blocks.Data('data', tf.float32))
  report = building_blocks.Lambda(
      report_parameter_name,
      tf.float32,
      building_blocks.Data('data', tf.bool))

  return building_block_factory.create_federated_aggregate(
      client_values, zero, accumulate, merge, report)
Esempio n. 6
0
    def test_does_not_remove_called_lambda(self):
        """A plain call to an identity lambda is expected to stay unchanged."""
        identity_fn = building_block_test_utils.create_identity_function(
            'a', tf.int32)
        comp = building_blocks.Call(identity_fn,
                                    building_blocks.Data('data', tf.int32))

        result, modified = (
            tree_transformations.remove_mapped_or_applied_identity(comp))

        self.assertEqual(result.compact_representation(),
                         comp.compact_representation())
        self.assertEqual(result.compact_representation(), '(a -> a)(data)')
        self.assertEqual(result.type_signature, comp.type_signature)
        self.assertFalse(modified)
Esempio n. 7
0
    def test_single_level_block(self):
        """Repeated bindings of `a` in one block get fresh `_varN` names."""
        a_ref = building_blocks.Reference('a', tf.int32)
        a_data = building_blocks.Data('data', tf.int32)
        bindings = (('a', a_data), ('a', a_ref), ('a', a_ref))
        block = building_blocks.Block(bindings, a_ref)

        result, modified = tree_transformations.uniquify_reference_names(block)

        self.assertEqual(block.compact_representation(),
                         '(let a=data,a=a,a=a in a)')
        self.assertEqual(result.compact_representation(),
                         '(let a=data,_var1=a,_var2=_var1 in _var2)')
        tree_analysis.check_has_unique_names(result)
        self.assertTrue(modified)
Esempio n. 8
0
 def test_returns_string_for_tuple_with_names(self):
     """Checks the compact, formatted and structural strings of a named Tuple."""
     data = building_blocks.Data('data', tf.int32)
     comp = building_blocks.Tuple((('a', data), ('b', data)))
     compact_string = comp.compact_representation()
     self.assertEqual(compact_string, '<a=data,b=data>')
     formatted_string = comp.formatted_representation()
     # pyformat: disable
     self.assertEqual(formatted_string, '<\n'
                      '  a=data,\n'
                      '  b=data\n'
                      '>')
     # pyformat: enable
     structural_string = comp.structural_representation()
     # pyformat: disable
     self.assertEqual(
         structural_string,
         'Tuple\n'
         '|\n'
         '[a=data, b=data]'
     )
     # pyformat: enable
 def test_basic_functionality_of_data_class(self):
     """Checks `Data` accessors, `repr`, compact form and proto round-trip."""
     sequence_type = computation_types.SequenceType(tf.int32)
     data_comp = building_blocks.Data('/tmp/mydata', sequence_type)
     self.assertEqual(str(data_comp.type_signature), 'int32*')
     self.assertEqual(data_comp.uri, '/tmp/mydata')
     self.assertEqual(
         repr(data_comp),
         'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
     self.assertEqual(data_comp.compact_representation(), '/tmp/mydata')
     serialized = data_comp.proto
     self.assertEqual(
         type_serialization.deserialize_type(serialized.type),
         data_comp.type_signature)
     self.assertEqual(serialized.WhichOneof('computation'), 'data')
     self.assertEqual(serialized.data.uri, data_comp.uri)
     self._serialize_deserialize_roundtrip_test(data_comp)
Esempio n. 10
0
def create_whimsy_called_sequence_map(parameter_name, parameter_type=tf.int32):
    r"""Builds a whimsy call to the `sequence_map` intrinsic.

                 Call
                /    \
    sequence_map      data

    Args:
      parameter_name: The parameter name of the identity function mapped over
        the sequence.
      parameter_type: The element type of the sequence, which is also the
        parameter type of the identity function.

    Returns:
      The `building_blocks.Call` built by
      `building_block_factory.create_sequence_map`.
    """
    identity_fn = create_identity_function(parameter_name, parameter_type)
    sequence = building_blocks.Data(
        'data', computation_types.SequenceType(parameter_type))
    return building_block_factory.create_sequence_map(identity_fn, sequence)
Esempio n. 11
0
  def test_returns_string_for_call_with_arg(self):
    """Checks all three string forms of a `Call` with an argument."""
    fn_type = computation_types.FunctionType(tf.int32, tf.int32)
    fn = building_blocks.Reference('a', fn_type)
    arg = building_blocks.Data('data', tf.int32)
    comp = building_blocks.Call(fn, arg)

    self.assertEqual(comp.compact_representation(), 'a(data)')
    self.assertEqual(comp.formatted_representation(), 'a(data)')
    # pyformat: disable
    self.assertEqual(
        comp.structural_representation(),
        '       Call\n'
        '      /    \\\n'
        'Ref(a)      data'
    )
    # pyformat: enable
Esempio n. 12
0
    def test_returns_string_for_block(self):
        """Checks all three string forms of a `Block` with two locals."""
        data = building_blocks.Data('data', tf.int32)
        ref = building_blocks.Reference('c', tf.int32)
        comp = building_blocks.Block((('a', data), ('b', data)), ref)

        self.assertEqual(comp.compact_representation(),
                         '(let a=data,b=data in c)')
        # pyformat: disable
        self.assertEqual(
            comp.formatted_representation(),
            '(let\n'
            '  a=data,\n'
            '  b=data\n'
            ' in c)'
        )
        self.assertEqual(
            comp.structural_representation(),
            '                 Block\n'
            '                /     \\\n'
            '[a=data, b=data]       Ref(c)'
        )
        # pyformat: enable
Esempio n. 13
0
    def test_nested_blocks(self):
        """Rebinding `a` across two nested blocks yields distinct names."""
        a_ref = building_blocks.Reference('a', tf.int32)
        data = building_blocks.Data('data', tf.int32)
        inner = building_blocks.Block([('a', data), ('a', a_ref)], a_ref)
        outer = building_blocks.Block([('a', data), ('a', a_ref)], inner)

        result, modified = tree_transformations.uniquify_reference_names(
            outer)

        self.assertEqual(outer.compact_representation(),
                         '(let a=data,a=a in (let a=data,a=a in a))')
        self.assertEqual(
            result.compact_representation(),
            '(let a=data,_var1=a in (let _var2=data,_var3=_var2 in _var3))')
        tree_analysis.check_has_unique_names(result)
        self.assertTrue(modified)
Esempio n. 14
0
def data(uri: str, type_spec: computation_types.Type):
  """Builds a TFF `data` computation for the given URI and TFF type.

  Args:
    uri: The URI of the data, as a `str`.
    type_spec: The TFF type of the data. It is normalized with
      `computation_types.to_type` before use.

  Returns:
    The value that `value_impl.to_value` produces for a `building_blocks.Data`
    with the given URI and type, for use in the body of a federated
    computation.

  Raises:
    TypeError: If the arguments are not of the types specified above.
  """
  py_typecheck.check_type(uri, str)
  resolved_type = computation_types.to_type(type_spec)
  data_comp = building_blocks.Data(uri, resolved_type)
  return value_impl.to_value(data_comp, resolved_type)
Esempio n. 15
0
    def test_removes_chained_federated_maps(self):
        """Both identity maps in a chain of federated maps are removed."""
        identity_fn = building_block_test_utils.create_identity_function(
            'a', tf.int32)
        clients_int = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
        comp = _create_chained_whimsy_federated_maps(
            [identity_fn, identity_fn],
            building_blocks.Data('data', clients_int))

        result, modified = (
            tree_transformations.remove_mapped_or_applied_identity(comp))

        self.assertEqual(
            comp.compact_representation(),
            'federated_map(<(a -> a),federated_map(<(a -> a),data>)>)')
        self.assertEqual(result.compact_representation(), 'data')
        self.assertEqual(result.type_signature, comp.type_signature)
        self.assertTrue(modified)
 def test_with_structure_replacing_federated_map(self):
   """A block calling its bound tuple's first element on its second reduces."""
   int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
   arg_ref = building_blocks.Reference('arg', [
       int_to_int,
       tf.int32,
   ])
   called_fn = building_blocks.Call(
       building_blocks.Selection(arg_ref, index=0),
       building_blocks.Selection(arg_ref, index=1))
   identity_lambda = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Reference('x', tf.int32))
   bound_tuple = building_blocks.Tuple(
       [identity_lambda, building_blocks.Data('a', tf.int32)])
   structure = building_blocks.Block([('arg', bound_tuple)], called_fn)

   result, modified = compiler_transformations.remove_lambdas_and_blocks(
       structure)

   self.assertTrue(modified)
   self.assertNoLambdasOrBlocks(result)
Esempio n. 17
0
    def test_removes_federated_map_with_named_result(self):
        """An identity federated map over a named tuple reduces to its arg."""
        parameter_type = [('a', tf.int32), ('b', tf.int32)]
        identity_fn = building_block_test_utils.create_identity_function(
            'c', parameter_type)
        clients_type = computation_types.FederatedType(parameter_type,
                                                       placements.CLIENTS)
        comp = building_block_factory.create_federated_map(
            identity_fn, building_blocks.Data('data', clients_type))

        result, modified = (
            tree_transformations.remove_mapped_or_applied_identity(comp))

        self.assertEqual(comp.compact_representation(),
                         'federated_map(<(c -> c),data>)')
        self.assertEqual(result.compact_representation(), 'data')
        self.assertEqual(result.type_signature, comp.type_signature)
        self.assertTrue(modified)
Esempio n. 18
0
    def test_nested_lambdas(self):
        """Already-unique names in nested called lambdas are left unchanged."""
        data = building_blocks.Data('data', tf.int32)
        a_ref = building_blocks.Reference('a', data.type_signature)
        inner_call = building_blocks.Call(
            building_blocks.Lambda('a', a_ref.type_signature, a_ref), data)
        b_ref = building_blocks.Reference('b', inner_call.type_signature)
        outer_call = building_blocks.Call(
            building_blocks.Lambda('b', b_ref.type_signature, b_ref),
            inner_call)

        result, modified = tree_transformations.uniquify_reference_names(
            outer_call)

        self.assertEqual(result.compact_representation(),
                         '(b -> b)((a -> a)(data))')
        tree_analysis.check_has_unique_names(result)
        self.assertFalse(modified)
 def test_with_higher_level_lambdas(self):
   """Reduces a lambda passed to, and called inside, another lambda.

   Currently skipped pending b/146904968.
   """
   self.skipTest('b/146904968')
   int_data = building_blocks.Data('a', tf.int32)
   z_ref = building_blocks.Reference('z', tf.int32)
   z_and_x = building_blocks.Tuple(
       [z_ref, building_blocks.Reference('x', tf.int32)])
   lowest_lambda = building_blocks.Lambda('z', tf.int32, z_and_x)
   middle_lambda = building_blocks.Lambda('x', tf.int32, lowest_lambda)
   middle_type = middle_lambda.type_signature
   applied_param = building_blocks.Call(
       building_blocks.Reference('x', middle_type), int_data)
   left_lambda = building_blocks.Lambda('x', middle_type, applied_param)
   high_call = building_blocks.Call(
       building_blocks.Call(left_lambda, middle_lambda), int_data)
   result, modified = compiler_transformations.remove_lambdas_and_blocks(
       high_call)
   self.assertTrue(modified)
   self.assertNoLambdasOrBlocks(result)
Esempio n. 20
0
  def test_returns_string_for_struct_with_no_names(self):
    """Checks all three string forms of an unnamed two-element `Struct`."""
    data = building_blocks.Data('data', tf.int32)
    comp = building_blocks.Struct([data, data])

    self.assertEqual(comp.compact_representation(), '<data,data>')
    # pyformat: disable
    self.assertEqual(
        comp.formatted_representation(),
        '<\n'
        '  data,\n'
        '  data\n'
        '>'
    )
    self.assertEqual(
        comp.structural_representation(),
        'Struct\n'
        '|\n'
        '[data, data]'
    )
    # pyformat: enable
Esempio n. 21
0
    def test_blocks_nested_inside_of_locals(self):
        """Blocks nested inside block locals get unique names (golden file)."""
        data = building_blocks.Data('data', tf.int32)

        def nest_three_blocks(innermost):
            # The innermost block binds 'a' to `innermost`; each of the two
            # enclosing blocks binds 'a' to the block below it. All three
            # blocks return `data`.
            nested = building_blocks.Block([('a', innermost)], data)
            for _ in range(2):
                nested = building_blocks.Block([('a', nested)], data)
            return nested

        higher_block = nest_three_blocks(data)
        higher_block_with_y_ref = nest_three_blocks(
            building_blocks.Reference('a', tf.int32))
        multiple_bindings_highest_block = building_blocks.Block(
            [('a', higher_block), ('a', higher_block_with_y_ref)],
            higher_block_with_y_ref)

        transformed_comp = self.assert_transforms(
            multiple_bindings_highest_block,
            'uniquify_names_blocks_nested_inside_of_locals.expected')
        tree_analysis.check_has_unique_names(transformed_comp)
Esempio n. 22
0
def create_dummy_called_federated_map(parameter_name, parameter_type=tf.int32):
  r"""Builds a dummy call to `federated_map` over an identity function.

                Call
               /    \
  federated_map      Tuple
                     |
                     [Lambda(x), data]
                      |
                      Ref(x)

  Args:
    parameter_name: The parameter name of the mapped identity function.
    parameter_type: The member type of the mapped value, which is placed at
      `CLIENTS`.

  Returns:
    The `building_blocks.Call` built by
    `building_block_factory.create_federated_map`.
  """
  mapping_fn = create_identity_function(parameter_name, parameter_type)
  clients_value = building_blocks.Data(
      'data', computation_types.FederatedType(parameter_type,
                                              placements.CLIENTS))
  return building_block_factory.create_federated_map(mapping_fn,
                                                      clients_value)
Esempio n. 23
0
  def test_replaces_chained_intrinsics(self):
    """Every intrinsic call in a chain is replaced by the given callable."""
    intrinsic_lambda = test_utils.create_lambda_to_dummy_called_intrinsic(
        parameter_name='a')
    comp = test_utils.create_chained_calls(
        [intrinsic_lambda, intrinsic_lambda],
        building_blocks.Data('data', tf.int32))
    uri = 'intrinsic'
    body = lambda x: x

    result, modified = value_transformations.replace_intrinsics_with_callable(
        comp, uri, body, context_stack_impl.context_stack)

    self.assertEqual(comp.compact_representation(),
                     '(a -> intrinsic(a))((a -> intrinsic(a))(data))')
    self.assertEqual(
        result.compact_representation(),
        '(a -> (intrinsic_arg -> intrinsic_arg)(a))((a -> (intrinsic_arg -> intrinsic_arg)(a))(data))'
    )
    self.assertEqual(result.type_signature, comp.type_signature)
    self.assertTrue(modified)
Esempio n. 24
0
def create_whimsy_called_federated_apply(parameter_name,
                                         parameter_type=tf.int32):
  r"""Builds a whimsy call to `federated_apply` over an identity function.

                  Call
                 /    \
  federated_apply      Tuple
                       |
                       [Lambda(x), data]
                        |
                        Ref(x)

  Args:
    parameter_name: The parameter name of the applied identity function.
    parameter_type: The member type of the applied value, which is placed at
      `SERVER`.

  Returns:
    The `building_blocks.Call` built by
    `building_block_factory.create_federated_apply`.
  """
  identity_fn = create_identity_function(parameter_name, parameter_type)
  server_value = building_blocks.Data(
      'data', computation_types.FederatedType(parameter_type,
                                              placements.SERVER))
  return building_block_factory.create_federated_apply(identity_fn,
                                                        server_value)
 def test_with_multiple_reference_indirection(self):
   """A call through a chain of block-bound references is fully reduced."""
   identity_lam = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Reference('x', tf.int32))
   ref_to_lambda = building_blocks.Reference('a', identity_lam.type_signature)
   tuple_wrapping_ref = building_blocks.Tuple([ref_to_lambda])
   ref_to_tuple = building_blocks.Reference(
       'b', tuple_wrapping_ref.type_signature)
   selection_from_ref = building_blocks.Selection(ref_to_tuple, index=0)
   int_data = building_blocks.Data('a', tf.int32)
   called_lambda_with_indirection = building_blocks.Call(
       building_blocks.Reference('c', selection_from_ref.type_signature),
       int_data)
   bindings = [
       ('a', identity_lam),
       ('b', tuple_wrapping_ref),
       ('c', selection_from_ref),
   ]
   blk = building_blocks.Block(bindings, called_lambda_with_indirection)
   result, modified = compiler_transformations.remove_lambdas_and_blocks(blk)
   self.assertTrue(modified)
   self.assertNoLambdasOrBlocks(result)
Esempio n. 26
0
  def test_handles_federated_broadcasts_nested_in_tuple(self):
    """Neither split half keeps a called broadcast when one is tuple-nested."""
    first_broadcast = (
        compiler_test_utils.create_dummy_called_federated_broadcast())
    server_int_type = computation_types.FederatedType(
        computation_types.TensorType(tf.int32), placements.SERVER)
    packed_broadcast = building_blocks.Tuple(
        [building_blocks.Data('a', server_int_type), first_broadcast])
    second_broadcast = building_block_factory.create_federated_broadcast(
        building_blocks.Selection(packed_broadcast, index=0))
    comp = building_blocks.Lambda('a', tf.int32, second_broadcast)
    uri = [intrinsic_defs.FEDERATED_BROADCAST.uri]

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, uri)

    for part in (before, after):
      self.assertIsInstance(part, building_blocks.Lambda)
      self.assertFalse(tree_analysis.contains_called_intrinsic(part, uri))
Esempio n. 27
0
    def test_block_lambda_block_lambda(self):
        """Alternating blocks and lambdas that all rebind `a` get unique names."""
        a_ref = building_blocks.Reference('a', tf.int32)
        innermost_call = building_blocks.Call(
            building_blocks.Lambda('a', tf.int32, a_ref), a_ref)
        inner_block = building_blocks.Block([('a', a_ref), ('a', a_ref)],
                                            innermost_call)
        middle_call = building_blocks.Call(
            building_blocks.Lambda('a', tf.int32, inner_block), a_ref)
        outer_block = building_blocks.Block(
            [('a', building_blocks.Data('data', tf.int32)), ('a', a_ref)],
            middle_call)

        result, modified = tree_transformations.uniquify_reference_names(
            outer_block)

        self.assertEqual(
            outer_block.compact_representation(),
            '(let a=data,a=a in (a -> (let a=a,a=a in (a -> a)(a)))(a))')
        self.assertEqual(
            result.compact_representation(),
            '(let a=data,_var1=a in (_var2 -> (let _var3=_var2,_var4=_var3 in (_var5 -> _var5)(_var4)))(_var1))'
        )
        tree_analysis.check_has_unique_names(result)
        self.assertTrue(modified)
Esempio n. 28
0
    def test_compiles_lambda_under_federated_comp_to_tf(self):
        """An identity lambda under `federated_apply` is compiled to TF."""
        x_ref = building_blocks.Reference(
            'x', computation_types.StructType([tf.int32, tf.float32]))
        identity_lambda = building_blocks.Lambda(x_ref.name,
                                                 x_ref.type_signature, x_ref)
        server_struct_type = computation_types.FederatedType(
            computation_types.StructType([tf.int32, tf.float32]),
            placements.SERVER)
        federated_data = building_blocks.Data('a', server_struct_type)
        applied = building_block_factory.create_federated_apply(
            identity_lambda, federated_data)

        transformed = compiler.compile_local_subcomputations_to_tensorflow(
            applied)

        # The call structure is kept; only the local function is compiled.
        self.assertIsInstance(transformed, building_blocks.Call)
        self.assertIsInstance(transformed.function, building_blocks.Intrinsic)
        self.assertIsInstance(transformed.argument[0],
                              building_blocks.CompiledComputation)
        self.assertEqual(transformed.argument[1], federated_data)
        self.assertEqual(transformed.argument[0].type_signature,
                         identity_lambda.type_signature)
Esempio n. 29
0
 def test_returns_string_for_comp_with_right_overhang(self):
   """Checks all three string forms of a called lambda selecting from a tuple."""
   ref = building_blocks.Reference('a', tf.int32)
   data = building_blocks.Data('data', tf.int32)
   tup = building_blocks.Tuple([ref, data, data, data, data])
   sel = building_blocks.Selection(tup, index=0)
   fn = building_blocks.Lambda(ref.name, ref.type_signature, sel)
   comp = building_blocks.Call(fn, data)
   compact_string = comp.compact_representation()
   self.assertEqual(compact_string, '(a -> <a,data,data,data,data>[0])(data)')
   formatted_string = comp.formatted_representation()
   # pyformat: disable
   self.assertEqual(
       formatted_string,
       '(a -> <\n'
       '  a,\n'
       '  data,\n'
       '  data,\n'
       '  data,\n'
       '  data\n'
       '>[0])(data)'
   )
   # pyformat: enable
   structural_string = comp.structural_representation()
   # pyformat: disable
   self.assertEqual(
       structural_string,
       '          Call\n'
       '         /    \\\n'
       'Lambda(a)      data\n'
       '|\n'
       'Sel(0)\n'
       '|\n'
       'Tuple\n'
       '|\n'
       '[Ref(a), data, data, data, data]'
   )
   # pyformat: enable
Esempio n. 30
0
 def test_ok_lambda_binding_of_new_variable(self):
     """A block binding `x` whose result is a lambda binding `y` is accepted.

     `x` and `y` are distinct, so `tree_analysis.check_has_unique_names` is
     expected not to raise for this tree.
     """
     y_ref = building_blocks.Reference('y', tf.int32)
     lambda_1 = building_blocks.Lambda('y', tf.int32, y_ref)
     x_data = building_blocks.Data('x', tf.int32)
     single_block = building_blocks.Block([('x', x_data)], lambda_1)
     # The check itself is the assertion: the test fails only if it raises.
     tree_analysis.check_has_unique_names(single_block)