예제 #1
0
 def test_raises_type_error(self):
     """Asserts each representation function raises `TypeError` on `None`."""
     representation_fns = (
         computation_building_blocks.compact_representation,
         computation_building_blocks.formatted_representation,
         computation_building_blocks.structural_representation,
     )
     # Same order as the individual checks: compact, formatted, structural.
     for representation_fn in representation_fns:
         with self.assertRaises(TypeError):
             representation_fn(None)
예제 #2
0
    def test_returns_string_for_comp_with_left_overhang(self):
        """Checks the three renderings of `a(comp#bbbbb())`.

        The computation is a call of the function reference 'a' whose argument
        is itself a call of a compiled computation named 'bbbbb'.
        """
        fn_ref = computation_building_blocks.Reference(
            'a', computation_types.FunctionType(tf.int32, tf.int32))
        tf_proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        called_compiled = computation_building_blocks.Call(
            computation_building_blocks.CompiledComputation(tf_proto, 'bbbbb'))
        comp = computation_building_blocks.Call(fn_ref, called_compiled)

        # The compact and formatted forms are the same single line here.
        self.assertEqual(
            computation_building_blocks.compact_representation(comp),
            'a(comp#bbbbb())')
        self.assertEqual(
            computation_building_blocks.formatted_representation(comp),
            'a(comp#bbbbb())')

        # pyformat: disable
        expected_structural = (
            '           Call\n'
            '          /    \\\n'
            '    Ref(a)      Call\n'
            '               /\n'
            'Compiled(bbbbb)'
        )
        # pyformat: enable
        self.assertEqual(
            computation_building_blocks.structural_representation(comp),
            expected_structural)
예제 #3
0
 def test_returns_string_for_federated_map(self):
     """Checks the three string renderings of a called federated map.

     The computation under test is produced by the
     `create_dummy_called_federated_map` test helper, invoked with
     `parameter_name='a'`.
     """
     called_map = computation_test_utils.create_dummy_called_federated_map(
         parameter_name='a')

     self.assertEqual(
         computation_building_blocks.compact_representation(called_map),
         'federated_map(<(a -> a),data>)')

     # pyformat: disable
     expected_formatted = (
         'federated_map(<\n'
         '  (a -> a),\n'
         '  data\n'
         '>)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.formatted_representation(called_map),
         expected_formatted)

     # pyformat: disable
     expected_structural = (
         '              Call\n'
         '             /    \\\n'
         'federated_map      Tuple\n'
         '                   |\n'
         '                   [Lambda(a), data]\n'
         '                    |\n'
         '                    Ref(a)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(called_map),
         expected_structural)
예제 #4
0
 def test_returns_string_for_federated_aggregate(self):
     """Checks the three string renderings of a called federated aggregate.

     The computation under test is produced by the
     `create_dummy_called_federated_aggregate` test helper, with the
     accumulate, merge and report parameters named 'a', 'b' and 'c'.
     """
     called_aggregate = (
         computation_test_utils.create_dummy_called_federated_aggregate(
             accumulate_parameter_name='a',
             merge_parameter_name='b',
             report_parameter_name='c'))

     expected_compact = (
         'federated_aggregate(<data,data,(a -> data),(b -> data),(c -> data)>)'
     )
     self.assertEqual(
         computation_building_blocks.compact_representation(called_aggregate),
         expected_compact)

     # pyformat: disable
     expected_formatted = (
         'federated_aggregate(<\n'
         '  data,\n'
         '  data,\n'
         '  (a -> data),\n'
         '  (b -> data),\n'
         '  (c -> data)\n'
         '>)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.formatted_representation(
             called_aggregate),
         expected_formatted)

     # pyformat: disable
     expected_structural = (
         '                    Call\n'
         '                   /    \\\n'
         'federated_aggregate      Tuple\n'
         '                         |\n'
         '                         [data, data, Lambda(a), Lambda(b), Lambda(c)]\n'
         '                                      |          |          |\n'
         '                                      data       data       data'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(
             called_aggregate),
         expected_structural)
예제 #5
0
 def test_returns_string_for_reference(self):
     """Checks the three string renderings of a `Reference` named 'a'."""
     reference = computation_building_blocks.Reference('a', tf.int32)

     # Compact and formatted renderings are just the name.
     self.assertEqual(
         computation_building_blocks.compact_representation(reference), 'a')
     self.assertEqual(
         computation_building_blocks.formatted_representation(reference), 'a')
     # The structural rendering wraps the name in `Ref(...)`.
     self.assertEqual(
         computation_building_blocks.structural_representation(reference),
         'Ref(a)')
예제 #6
0
 def test_returns_string_for_placement(self):
     """Checks the three string renderings of a `Placement` at CLIENTS."""
     placement = computation_building_blocks.Placement(placements.CLIENTS)

     self.assertEqual(
         computation_building_blocks.compact_representation(placement),
         'CLIENTS')
     self.assertEqual(
         computation_building_blocks.formatted_representation(placement),
         'CLIENTS')
     # The structural rendering is expected to be the literal 'Placement'.
     self.assertEqual(
         computation_building_blocks.structural_representation(placement),
         'Placement')
예제 #7
0
 def test_returns_string_for_intrinsic(self):
     """Checks that an `Intrinsic` renders as its URI in all three forms."""
     intrinsic = computation_building_blocks.Intrinsic('intrinsic', tf.int32)

     self.assertEqual(
         computation_building_blocks.compact_representation(intrinsic),
         'intrinsic')
     self.assertEqual(
         computation_building_blocks.formatted_representation(intrinsic),
         'intrinsic')
     self.assertEqual(
         computation_building_blocks.structural_representation(intrinsic),
         'intrinsic')
예제 #8
0
 def test_returns_string_for_compiled_computation(self):
     """Checks the renderings of a `CompiledComputation` named 'a'."""
     tf_proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda: tf.constant(1), None, context_stack_impl.context_stack)
     compiled = computation_building_blocks.CompiledComputation(tf_proto, 'a')

     # Compact and formatted renderings prefix the name with 'comp#'.
     self.assertEqual(
         computation_building_blocks.compact_representation(compiled),
         'comp#a')
     self.assertEqual(
         computation_building_blocks.formatted_representation(compiled),
         'comp#a')
     # The structural rendering wraps the name in `Compiled(...)`.
     self.assertEqual(
         computation_building_blocks.structural_representation(compiled),
         'Compiled(a)')
예제 #9
0
 def test_returns_string_for_selection_with_index(self):
     """Checks the renderings of selecting index 0 from a reference 'a'."""
     named_elements = (('b', tf.int32), ('c', tf.bool))
     source = computation_building_blocks.Reference('a', named_elements)
     selection = computation_building_blocks.Selection(source, index=0)

     self.assertEqual(
         computation_building_blocks.compact_representation(selection),
         'a[0]')
     self.assertEqual(
         computation_building_blocks.formatted_representation(selection),
         'a[0]')

     # pyformat: disable
     expected_structural = (
         'Sel(0)\n'
         '|\n'
         'Ref(a)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(selection),
         expected_structural)
예제 #10
0
 def test_returns_string_for_lambda(self):
     """Checks the renderings of the identity lambda `(a -> a)`."""
     param = computation_building_blocks.Reference('a', tf.int32)
     # The lambda's body is the reference to its own parameter.
     identity_fn = computation_building_blocks.Lambda(
         param.name, param.type_signature, param)

     self.assertEqual(
         computation_building_blocks.compact_representation(identity_fn),
         '(a -> a)')
     self.assertEqual(
         computation_building_blocks.formatted_representation(identity_fn),
         '(a -> a)')

     # pyformat: disable
     expected_structural = (
         'Lambda(a)\n'
         '|\n'
         'Ref(a)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(identity_fn),
         expected_structural)
예제 #11
0
 def test_returns_string_for_tuple_with_no_names(self):
     """Checks the renderings of an unnamed two-element `Tuple` of data."""
     element = computation_building_blocks.Data('data', tf.int32)
     pair = computation_building_blocks.Tuple((element, element))

     self.assertEqual(
         computation_building_blocks.compact_representation(pair),
         '<data,data>')

     # pyformat: disable
     expected_formatted = (
         '<\n'
         '  data,\n'
         '  data\n'
         '>'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.formatted_representation(pair),
         expected_formatted)

     # pyformat: disable
     expected_structural = (
         'Tuple\n'
         '|\n'
         '[data, data]'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(pair),
         expected_structural)
예제 #12
0
 def test_returns_string_for_call_with_arg(self):
     """Checks the renderings of the identity lambda called on `data`."""
     param = computation_building_blocks.Reference('a', tf.int32)
     identity_fn = computation_building_blocks.Lambda(
         param.name, param.type_signature, param)
     argument = computation_building_blocks.Data('data', tf.int32)
     call = computation_building_blocks.Call(identity_fn, argument)

     self.assertEqual(
         computation_building_blocks.compact_representation(call),
         '(a -> a)(data)')
     self.assertEqual(
         computation_building_blocks.formatted_representation(call),
         '(a -> a)(data)')

     # pyformat: disable
     expected_structural = (
         '          Call\n'
         '         /    \\\n'
         'Lambda(a)      data\n'
         '|\n'
         'Ref(a)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(call),
         expected_structural)
예제 #13
0
 def test_returns_string_for_block(self):
     """Checks the renderings of a `Block` with two locals and result 'c'."""
     value = computation_building_blocks.Data('data', tf.int32)
     result = computation_building_blocks.Reference('c', tf.int32)
     # Both locals, 'a' and 'b', are bound to the same data element.
     local_bindings = (('a', value), ('b', value))
     block = computation_building_blocks.Block(local_bindings, result)

     self.assertEqual(
         computation_building_blocks.compact_representation(block),
         '(let a=data,b=data in c)')

     # pyformat: disable
     expected_formatted = (
         '(let\n'
         '  a=data,\n'
         '  b=data\n'
         ' in c)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.formatted_representation(block),
         expected_formatted)

     # pyformat: disable
     expected_structural = (
         '                 Block\n'
         '                /     \\\n'
         '[a=data, b=data]       Ref(c)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(block),
         expected_structural)
예제 #14
0
 def test_returns_string_for_comp_with_right_overhang(self):
     """Checks the renderings of a called lambda with a wide tuple body.

     The lambda's body selects index 0 from a five-element tuple made of the
     parameter reference followed by four data elements.
     """
     param = computation_building_blocks.Reference('a', tf.int32)
     value = computation_building_blocks.Data('data', tf.int32)
     elements = computation_building_blocks.Tuple(
         [param, value, value, value, value])
     first_element = computation_building_blocks.Selection(elements, index=0)
     fn = computation_building_blocks.Lambda(
         param.name, param.type_signature, first_element)
     call = computation_building_blocks.Call(fn, value)

     self.assertEqual(
         computation_building_blocks.compact_representation(call),
         '(a -> <a,data,data,data,data>[0])(data)')

     # pyformat: disable
     expected_formatted = (
         '(a -> <\n'
         '  a,\n'
         '  data,\n'
         '  data,\n'
         '  data,\n'
         '  data\n'
         '>[0])(data)'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.formatted_representation(call),
         expected_formatted)

     # pyformat: disable
     expected_structural = (
         '          Call\n'
         '         /    \\\n'
         'Lambda(a)      data\n'
         '|\n'
         'Sel(0)\n'
         '|\n'
         'Tuple\n'
         '|\n'
         '[Ref(a), data, data, data, data]'
     )
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(call),
         expected_structural)