def test_raises_type_error(self):
  """Each representation function rejects `None` with a `TypeError`."""
  # Callable form of assertRaises; the three checks run in the same order
  # as before (compact, formatted, structural).
  self.assertRaises(
      TypeError, computation_building_blocks.compact_representation, None)
  self.assertRaises(
      TypeError, computation_building_blocks.formatted_representation, None)
  self.assertRaises(
      TypeError, computation_building_blocks.structural_representation, None)
def test_returns_string_for_comp_with_left_overhang(self):
  """Renders a call whose argument subtree is wider than the callee's.

  Builds `a(comp#bbbbb())`: a reference `a` called on the result of calling a
  compiled computation named `bbbbb`, then checks all three renderings.
  """
  callee_type = computation_types.FunctionType(tf.int32, tf.int32)
  callee = computation_building_blocks.Reference('a', callee_type)
  serialized, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(1), None, context_stack_impl.context_stack)
  compiled_comp = computation_building_blocks.CompiledComputation(
      serialized, 'bbbbb')
  inner_call = computation_building_blocks.Call(compiled_comp)
  outer_call = computation_building_blocks.Call(callee, inner_call)

  self.assertEqual(
      computation_building_blocks.compact_representation(outer_call),
      'a(comp#bbbbb())')
  self.assertEqual(
      computation_building_blocks.formatted_representation(outer_call),
      'a(comp#bbbbb())')

  # NOTE(review): the expected tree is whitespace-sensitive; the spacing below
  # is copied verbatim from the reviewed text -- confirm it matches the
  # renderer's column alignment.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(outer_call),
      ' Call\n'
      ' / \\\n'
      ' Ref(a) Call\n'
      ' /\n'
      'Compiled(bbbbb)'
  )
  # pyformat: enable
def test_returns_string_for_federated_map(self):
  """Checks the three renderings of a dummy called `federated_map`."""
  called_map = computation_test_utils.create_dummy_called_federated_map(
      parameter_name='a')

  self.assertEqual(
      computation_building_blocks.compact_representation(called_map),
      'federated_map(<(a -> a),data>)')

  # NOTE(review): whitespace inside the expected strings below is copied
  # verbatim from the reviewed text -- confirm it matches the renderer's
  # indentation and column alignment.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(called_map),
      'federated_map(<\n'
      ' (a -> a),\n'
      ' data\n'
      '>)'
  )
  # pyformat: enable

  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(called_map),
      ' Call\n'
      ' / \\\n'
      'federated_map Tuple\n'
      ' |\n'
      ' [Lambda(a), data]\n'
      ' |\n'
      ' Ref(a)'
  )
  # pyformat: enable
def test_returns_string_for_federated_aggregate(self):
  """Checks the three renderings of a dummy called `federated_aggregate`."""
  called_aggregate = (
      computation_test_utils.create_dummy_called_federated_aggregate(
          accumulate_parameter_name='a',
          merge_parameter_name='b',
          report_parameter_name='c'))

  self.assertEqual(
      computation_building_blocks.compact_representation(called_aggregate),
      'federated_aggregate(<data,data,(a -> data),(b -> data),(c -> data)>)'
  )

  # NOTE(review): whitespace inside the expected strings below is copied
  # verbatim from the reviewed text -- confirm it matches the renderer's
  # indentation and column alignment.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(called_aggregate),
      'federated_aggregate(<\n'
      ' data,\n'
      ' data,\n'
      ' (a -> data),\n'
      ' (b -> data),\n'
      ' (c -> data)\n'
      '>)'
  )
  # pyformat: enable

  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(called_aggregate),
      ' Call\n'
      ' / \\\n'
      'federated_aggregate Tuple\n'
      ' |\n'
      ' [data, data, Lambda(a), Lambda(b), Lambda(c)]\n'
      ' | | |\n'
      ' data data data'
  )
  # pyformat: enable
def test_returns_string_for_reference(self):
  """Checks the three renderings of a bare `Reference` named `a`."""
  reference = computation_building_blocks.Reference('a', tf.int32)

  self.assertEqual(
      computation_building_blocks.compact_representation(reference), 'a')
  self.assertEqual(
      computation_building_blocks.formatted_representation(reference), 'a')
  self.assertEqual(
      computation_building_blocks.structural_representation(reference),
      'Ref(a)')
def test_returns_string_for_placement(self):
  """Checks the three renderings of a `Placement` built from `CLIENTS`."""
  placement = computation_building_blocks.Placement(placements.CLIENTS)

  self.assertEqual(
      computation_building_blocks.compact_representation(placement),
      'CLIENTS')
  self.assertEqual(
      computation_building_blocks.formatted_representation(placement),
      'CLIENTS')
  self.assertEqual(
      computation_building_blocks.structural_representation(placement),
      'Placement')
def test_returns_string_for_intrinsic(self):
  """Checks that an `Intrinsic` renders as its URI in all three forms."""
  intrinsic = computation_building_blocks.Intrinsic('intrinsic', tf.int32)

  self.assertEqual(
      computation_building_blocks.compact_representation(intrinsic),
      'intrinsic')
  self.assertEqual(
      computation_building_blocks.formatted_representation(intrinsic),
      'intrinsic')
  self.assertEqual(
      computation_building_blocks.structural_representation(intrinsic),
      'intrinsic')
def test_returns_string_for_compiled_computation(self):
  """Checks the three renderings of a `CompiledComputation` named `a`."""
  serialized, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(1), None, context_stack_impl.context_stack)
  compiled_comp = computation_building_blocks.CompiledComputation(
      serialized, 'a')

  self.assertEqual(
      computation_building_blocks.compact_representation(compiled_comp),
      'comp#a')
  self.assertEqual(
      computation_building_blocks.formatted_representation(compiled_comp),
      'comp#a')
  self.assertEqual(
      computation_building_blocks.structural_representation(compiled_comp),
      'Compiled(a)')
def test_returns_string_for_selection_with_index(self):
  """Checks the three renderings of an index-based `Selection` on `a`."""
  source_ref = computation_building_blocks.Reference(
      'a', (('b', tf.int32), ('c', tf.bool)))
  selection = computation_building_blocks.Selection(source_ref, index=0)

  self.assertEqual(
      computation_building_blocks.compact_representation(selection), 'a[0]')
  self.assertEqual(
      computation_building_blocks.formatted_representation(selection), 'a[0]')

  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(selection),
      'Sel(0)\n'
      '|\n'
      'Ref(a)'
  )
  # pyformat: enable
def test_returns_string_for_lambda(self):
  """Checks the three renderings of the identity lambda `(a -> a)`."""
  param_ref = computation_building_blocks.Reference('a', tf.int32)
  identity_fn = computation_building_blocks.Lambda(
      param_ref.name, param_ref.type_signature, param_ref)

  self.assertEqual(
      computation_building_blocks.compact_representation(identity_fn),
      '(a -> a)')
  self.assertEqual(
      computation_building_blocks.formatted_representation(identity_fn),
      '(a -> a)')

  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(identity_fn),
      'Lambda(a)\n'
      '|\n'
      'Ref(a)'
  )
  # pyformat: enable
def test_returns_string_for_tuple_with_no_names(self):
  """Checks the three renderings of an unnamed two-element `Tuple`."""
  element = computation_building_blocks.Data('data', tf.int32)
  pair = computation_building_blocks.Tuple((element, element))

  self.assertEqual(
      computation_building_blocks.compact_representation(pair),
      '<data,data>')

  # NOTE(review): the leading whitespace of the element lines is copied
  # verbatim from the reviewed text -- confirm it matches the renderer's
  # indentation.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(pair),
      '<\n'
      ' data,\n'
      ' data\n'
      '>'
  )
  # pyformat: enable

  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(pair),
      'Tuple\n'
      '|\n'
      '[data, data]'
  )
  # pyformat: enable
def test_returns_string_for_call_with_arg(self):
  """Checks the three renderings of `(a -> a)` called on `data`."""
  param_ref = computation_building_blocks.Reference('a', tf.int32)
  identity_fn = computation_building_blocks.Lambda(
      param_ref.name, param_ref.type_signature, param_ref)
  argument = computation_building_blocks.Data('data', tf.int32)
  call = computation_building_blocks.Call(identity_fn, argument)

  self.assertEqual(
      computation_building_blocks.compact_representation(call),
      '(a -> a)(data)')
  self.assertEqual(
      computation_building_blocks.formatted_representation(call),
      '(a -> a)(data)')

  # NOTE(review): the expected tree is whitespace-sensitive; the spacing below
  # is copied verbatim from the reviewed text -- confirm it matches the
  # renderer's column alignment.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(call),
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Ref(a)'
  )
  # pyformat: enable
def test_returns_string_for_block(self):
  """Checks the three renderings of a `Block` with locals `a` and `b`."""
  bound_value = computation_building_blocks.Data('data', tf.int32)
  result_ref = computation_building_blocks.Reference('c', tf.int32)
  block = computation_building_blocks.Block(
      (('a', bound_value), ('b', bound_value)), result_ref)

  self.assertEqual(
      computation_building_blocks.compact_representation(block),
      '(let a=data,b=data in c)')

  # NOTE(review): whitespace inside the expected strings below is copied
  # verbatim from the reviewed text -- confirm it matches the renderer's
  # indentation and column alignment.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(block),
      '(let\n'
      ' a=data,\n'
      ' b=data\n'
      ' in c)'
  )
  # pyformat: enable

  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(block),
      ' Block\n'
      ' / \\\n'
      '[a=data, b=data] Ref(c)'
  )
  # pyformat: enable
def test_returns_string_for_comp_with_right_overhang(self):
  """Renders a call whose function subtree is wider than its argument.

  Builds `(a -> <a,data,data,data,data>[0])(data)` and checks all three
  renderings.
  """
  param_ref = computation_building_blocks.Reference('a', tf.int32)
  datum = computation_building_blocks.Data('data', tf.int32)
  elements = computation_building_blocks.Tuple(
      [param_ref, datum, datum, datum, datum])
  first_element = computation_building_blocks.Selection(elements, index=0)
  selector_fn = computation_building_blocks.Lambda(
      param_ref.name, param_ref.type_signature, first_element)
  call = computation_building_blocks.Call(selector_fn, datum)

  self.assertEqual(
      computation_building_blocks.compact_representation(call),
      '(a -> <a,data,data,data,data>[0])(data)')

  # NOTE(review): whitespace inside the expected strings below is copied
  # verbatim from the reviewed text -- confirm it matches the renderer's
  # indentation and column alignment.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(call),
      '(a -> <\n'
      ' a,\n'
      ' data,\n'
      ' data,\n'
      ' data,\n'
      ' data\n'
      '>[0])(data)'
  )
  # pyformat: enable

  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(call),
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Sel(0)\n'
      '|\n'
      'Tuple\n'
      '|\n'
      '[Ref(a), data, data, data, data]'
  )
  # pyformat: enable