def test_raises_type_error(self):
  """Each representation helper rejects `None` with a `TypeError`."""
  representation_fns = (
      computation_building_blocks.compact_representation,
      computation_building_blocks.formatted_representation,
      computation_building_blocks.structural_representation,
  )
  for representation_fn in representation_fns:
    with self.assertRaises(TypeError):
      representation_fn(None)
def test_basic_functionality_of_tuple_class(self):
  """Checks construction, accessors, repr and proto form of `Tuple`."""
  foo_ref = computation_building_blocks.Reference('foo', tf.int32)
  bar_ref = computation_building_blocks.Reference('bar', tf.bool)
  tup = computation_building_blocks.Tuple([foo_ref, ('y', bar_ref)])
  # An element named with the empty string is rejected.
  with self.assertRaises(ValueError):
    _ = computation_building_blocks.Tuple([('', bar_ref)])
  self.assertIsInstance(tup, anonymous_tuple.AnonymousTuple)
  self.assertEqual(str(tup.type_signature), '<int32,y=bool>')
  expected_repr = (
      'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
      'Reference(\'bar\', TensorType(tf.bool)))])')
  self.assertEqual(repr(tup), expected_repr)
  self.assertEqual(
      computation_building_blocks.compact_representation(tup), '<foo,y=bar>')
  # Only the named element shows up in `dir`.
  self.assertEqual(dir(tup), ['y'])
  self.assertIs(tup.y, bar_ref)
  self.assertLen(tup, 2)
  self.assertIs(tup[0], foo_ref)
  self.assertIs(tup[1], bar_ref)
  element_strings = [
      computation_building_blocks.compact_representation(element)
      for element in tup
  ]
  self.assertEqual(','.join(element_strings), 'foo,bar')
  tuple_proto = tup.proto
  self.assertEqual(
      type_serialization.deserialize_type(tuple_proto.type),
      tup.type_signature)
  self.assertEqual(tuple_proto.WhichOneof('computation'), 'tuple')
  # The unnamed first element is serialized with an empty name.
  self.assertEqual(
      [element.name for element in tuple_proto.tuple.element], ['', 'y'])
  self._serialize_deserialize_roundtrip_test(tup)
def _serialize_deserialize_roundtrip_test(self, target):
  """Performs roundtrip serialization/deserialization of the given target.

  Serializes `target` to its proto and rebuilds a building block from that
  proto with `ComputationBuildingBlock.from_proto`. It then checks that the
  rebuilt block has the same compact representation, and that re-serializing
  it yields the same proto text.

  Args:
    target: An instance of ComputationBuildingBlock to serialize-deserialize.
  """
  # A test assertion rather than a bare `assert`, which Python strips when
  # run with -O and which produces a less helpful failure message.
  self.assertIsInstance(target,
                        computation_building_blocks.ComputationBuildingBlock)
  proto = target.proto
  target2 = computation_building_blocks.ComputationBuildingBlock.from_proto(
      proto)
  proto2 = target2.proto
  self.assertEqual(
      computation_building_blocks.compact_representation(target),
      computation_building_blocks.compact_representation(target2))
  self.assertEqual(str(proto), str(proto2))
def test_intrinsic_class_succeeds_simple_federated_map(self):
  """Builds a concretely-typed `federated_map` intrinsic and checks it."""
  mapping_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  client_arg_type = computation_types.FederatedType(
      mapping_fn_type.parameter, placements.CLIENTS)
  client_result_type = computation_types.FederatedType(
      mapping_fn_type.result, placements.CLIENTS)
  map_type = computation_types.FunctionType(
      [mapping_fn_type, client_arg_type], client_result_type)
  federated_map = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri, map_type)
  self.assertIsInstance(federated_map, computation_building_blocks.Intrinsic)
  self.assertEqual(
      str(federated_map.type_signature),
      '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
  self.assertEqual(federated_map.uri, 'federated_map')
  self.assertEqual(
      computation_building_blocks.compact_representation(federated_map),
      'federated_map')
  map_proto = federated_map.proto
  self.assertEqual(
      type_serialization.deserialize_type(map_proto.type),
      federated_map.type_signature)
  self.assertEqual(map_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(map_proto.intrinsic.uri, federated_map.uri)
  self._serialize_deserialize_roundtrip_test(federated_map)
def test_returns_string_for_federated_aggregate(self):
  """Checks all three string renderings of a called `federated_aggregate`."""
  aggregate_call = (
      computation_test_utils.create_dummy_called_federated_aggregate(
          accumulate_parameter_name='a',
          merge_parameter_name='b',
          report_parameter_name='c'))
  self.assertEqual(
      computation_building_blocks.compact_representation(aggregate_call),
      'federated_aggregate(<data,data,(a -> data),(b -> data),(c -> data)>)'
  )
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(aggregate_call),
      'federated_aggregate(<\n'
      ' data,\n'
      ' data,\n'
      ' (a -> data),\n'
      ' (b -> data),\n'
      ' (c -> data)\n'
      '>)')
  # pyformat: enable
  # NOTE(review): the tree edges in this expected string do not line up;
  # runs of spaces may have been collapsed to single spaces. Confirm against
  # the actual `structural_representation` output.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(aggregate_call),
      ' Call\n'
      ' / \\\n'
      'federated_aggregate Tuple\n'
      ' |\n'
      ' [data, data, Lambda(a), Lambda(b), Lambda(c)]\n'
      ' | | |\n'
      ' data data data')
  # pyformat: enable
def test_returns_string_for_federated_map(self):
  """Checks all three string renderings of a called `federated_map`."""
  map_call = computation_test_utils.create_dummy_called_federated_map(
      parameter_name='a')
  self.assertEqual(
      computation_building_blocks.compact_representation(map_call),
      'federated_map(<(a -> a),data>)')
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(map_call),
      'federated_map(<\n'
      ' (a -> a),\n'
      ' data\n'
      '>)')
  # pyformat: enable
  # NOTE(review): the tree edges in this expected string do not line up;
  # runs of spaces may have been collapsed to single spaces. Confirm against
  # the actual `structural_representation` output.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(map_call),
      ' Call\n'
      ' / \\\n'
      'federated_map Tuple\n'
      ' |\n'
      ' [Lambda(a), data]\n'
      ' |\n'
      ' Ref(a)')
  # pyformat: enable
def test_returns_string_for_comp_with_left_overhang(self):
  """Renders a call whose argument subtree is wider than its parent label."""
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  fn_ref = computation_building_blocks.Reference('a', fn_type)
  constant_proto, _ = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          lambda: tf.constant(1), None, context_stack_impl.context_stack))
  compiled_fn = computation_building_blocks.CompiledComputation(
      constant_proto, 'bbbbb')
  compiled_call = computation_building_blocks.Call(compiled_fn)
  outer_call = computation_building_blocks.Call(fn_ref, compiled_call)
  self.assertEqual(
      computation_building_blocks.compact_representation(outer_call),
      'a(comp#bbbbb())')
  self.assertEqual(
      computation_building_blocks.formatted_representation(outer_call),
      'a(comp#bbbbb())')
  # NOTE(review): the tree edges in this expected string do not line up;
  # runs of spaces may have been collapsed to single spaces. Confirm against
  # the actual `structural_representation` output.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(outer_call),
      ' Call\n'
      ' / \\\n'
      ' Ref(a) Call\n'
      ' /\n'
      'Compiled(bbbbb)')
  # pyformat: enable
def test_basic_functionality_of_call_class(self):
  """Checks result type, accessors, repr, argument errors and proto of `Call`."""
  fn_ref = computation_building_blocks.Reference(
      'foo', computation_types.FunctionType(tf.int32, tf.bool))
  arg_ref = computation_building_blocks.Reference('bar', tf.int32)
  call = computation_building_blocks.Call(fn_ref, arg_ref)
  self.assertEqual(str(call.type_signature), 'bool')
  self.assertIs(call.function, fn_ref)
  self.assertIs(call.argument, arg_ref)
  expected_repr = (
      'Call(Reference(\'foo\', '
      'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
      'Reference(\'bar\', TensorType(tf.int32)))')
  self.assertEqual(repr(call), expected_repr)
  self.assertEqual(
      computation_building_blocks.compact_representation(call), 'foo(bar)')
  # Omitting the argument raises TypeError.
  with self.assertRaises(TypeError):
    computation_building_blocks.Call(fn_ref)
  # So does passing a float32 argument to an int32-parameter function.
  float_ref = computation_building_blocks.Reference('bak', tf.float32)
  with self.assertRaises(TypeError):
    computation_building_blocks.Call(fn_ref, float_ref)
  call_proto = call.proto
  self.assertEqual(
      type_serialization.deserialize_type(call_proto.type),
      call.type_signature)
  self.assertEqual(call_proto.WhichOneof('computation'), 'call')
  self.assertEqual(str(call_proto.call.function), str(fn_ref.proto))
  self.assertEqual(str(call_proto.call.argument), str(arg_ref.proto))
  self._serialize_deserialize_roundtrip_test(call)
def test_basic_functionality_of_selection_class(self):
  """Checks name- and index-based `Selection`, its errors and proto form.

  The unset `index`/`name` attributes are checked with `assertIsNone`. It
  states the intent directly and compares by identity rather than `==`.
  """
  x = computation_building_blocks.Reference('foo', [('bar', tf.int32),
                                                    ('baz', tf.bool)])
  # Selection by name.
  y = computation_building_blocks.Selection(x, name='bar')
  self.assertEqual(y.name, 'bar')
  self.assertIsNone(y.index)
  self.assertEqual(str(y.type_signature), 'int32')
  self.assertEqual(
      repr(y), 'Selection(Reference(\'foo\', NamedTupleType(['
      '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
      ', name=\'bar\')')
  self.assertEqual(computation_building_blocks.compact_representation(y),
                   'foo.bar')
  z = computation_building_blocks.Selection(x, name='baz')
  self.assertEqual(str(z.type_signature), 'bool')
  self.assertEqual(computation_building_blocks.compact_representation(z),
                   'foo.baz')
  # Selecting a name that is not in the tuple raises ValueError.
  with self.assertRaises(ValueError):
    _ = computation_building_blocks.Selection(x, name='bak')
  # Selection by index.
  x0 = computation_building_blocks.Selection(x, index=0)
  self.assertIsNone(x0.name)
  self.assertEqual(x0.index, 0)
  self.assertEqual(str(x0.type_signature), 'int32')
  self.assertEqual(
      repr(x0), 'Selection(Reference(\'foo\', NamedTupleType(['
      '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
      ', index=0)')
  self.assertEqual(
      computation_building_blocks.compact_representation(x0), 'foo[0]')
  x1 = computation_building_blocks.Selection(x, index=1)
  self.assertEqual(str(x1.type_signature), 'bool')
  self.assertEqual(
      computation_building_blocks.compact_representation(x1), 'foo[1]')
  # Out-of-range and negative indices both raise ValueError.
  with self.assertRaises(ValueError):
    _ = computation_building_blocks.Selection(x, index=2)
  with self.assertRaises(ValueError):
    _ = computation_building_blocks.Selection(x, index=-1)
  y_proto = y.proto
  self.assertEqual(type_serialization.deserialize_type(y_proto.type),
                   y.type_signature)
  self.assertEqual(y_proto.WhichOneof('computation'), 'selection')
  self.assertEqual(str(y_proto.selection.source), str(x.proto))
  self.assertEqual(y_proto.selection.name, 'bar')
  self._serialize_deserialize_roundtrip_test(y)
  self._serialize_deserialize_roundtrip_test(z)
  self._serialize_deserialize_roundtrip_test(x0)
  self._serialize_deserialize_roundtrip_test(x1)
def _check_whitelisted(comp):
  """Rejects intrinsics whose URI is not in `uri_whitelist`.

  `uri_whitelist` is a free variable resolved from the enclosing scope, which
  is not visible in this snippet.

  Args:
    comp: The building block to inspect.

  Returns:
    The tuple `(comp, False)`, with `comp` unchanged.

  Raises:
    ValueError: If `comp` is a `computation_building_blocks.Intrinsic` whose
      `uri` is not in `uri_whitelist`.
  """
  if not isinstance(comp, computation_building_blocks.Intrinsic):
    return comp, False
  if comp.uri in uri_whitelist:
    return comp, False
  raise ValueError(
      'Encountered an Intrinsic not currently reducible to aggregate or '
      'broadcast, the intrinsic {}'.format(
          computation_building_blocks.compact_representation(comp)))
def test_basic_functionality_of_compiled_computation_class(self):
  """Checks type, proto, repr and naming of `CompiledComputation`."""
  comp, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
  x = computation_building_blocks.CompiledComputation(comp)
  self.assertEqual(str(x.type_signature), '(int32 -> int32)')
  self.assertEqual(str(x.proto), str(comp))
  # `assertRegex` reports the mismatching text on failure, unlike
  # `assertTrue(re.match(...))`. It uses `re.search`, so the leading `^`
  # keeps the original anchor-at-start semantics of `re.match`.
  self.assertRegex(
      repr(x),
      r'^CompiledComputation\([0-9a-f]+, '
      r'FunctionType\(TensorType\(tf\.int32\), '
      r'TensorType\(tf\.int32\)\)\)')
  # Without an explicit name, the compact form is `comp#` followed by hex.
  self.assertRegex(
      computation_building_blocks.compact_representation(x),
      r'^comp#[0-9a-f]+')
  y = computation_building_blocks.CompiledComputation(comp, name='foo')
  self.assertEqual(computation_building_blocks.compact_representation(y),
                   'comp#foo')
  self._serialize_deserialize_roundtrip_test(x)
def _check_single_placement(comp):
  """Checks that a federated `comp` is placed at `single_placement`.

  `single_placement` is a free variable resolved from the enclosing scope,
  which is not visible in this snippet. Non-federated building blocks pass
  through without any check.

  Args:
    comp: The building block whose `type_signature` is inspected.

  Returns:
    The tuple `(comp, False)`, with `comp` unchanged.

  Raises:
    ValueError: If `comp.type_signature` is a
      `computation_types.FederatedType` whose `placement` differs from
      `single_placement`.
  """
  if (isinstance(comp.type_signature, computation_types.FederatedType) and
      comp.type_signature.placement != single_placement):
    raise ValueError(
        'Comp contains a placement other than {}; '
        'placement {} on comp {} inside the structure. '.format(
            single_placement, comp.type_signature.placement,
            computation_building_blocks.compact_representation(comp)))
  return comp, False
def test_returns_string_for_intrinsic(self):
  """An intrinsic renders as its URI in every representation."""
  intrinsic = computation_building_blocks.Intrinsic('intrinsic', tf.int32)
  self.assertEqual(
      computation_building_blocks.compact_representation(intrinsic),
      'intrinsic')
  self.assertEqual(
      computation_building_blocks.formatted_representation(intrinsic),
      'intrinsic')
  self.assertEqual(
      computation_building_blocks.structural_representation(intrinsic),
      'intrinsic')
def test_returns_string_for_reference(self):
  """A reference renders as its name, and structurally as `Ref(name)`."""
  ref = computation_building_blocks.Reference('a', tf.int32)
  self.assertEqual(
      computation_building_blocks.compact_representation(ref), 'a')
  self.assertEqual(
      computation_building_blocks.formatted_representation(ref), 'a')
  self.assertEqual(
      computation_building_blocks.structural_representation(ref), 'Ref(a)')
def test_returns_string_for_placement(self):
  """A CLIENTS placement renders as `CLIENTS`, structurally as `Placement`."""
  clients = computation_building_blocks.Placement(placements.CLIENTS)
  self.assertEqual(
      computation_building_blocks.compact_representation(clients), 'CLIENTS')
  self.assertEqual(
      computation_building_blocks.formatted_representation(clients),
      'CLIENTS')
  self.assertEqual(
      computation_building_blocks.structural_representation(clients),
      'Placement')
def test_raises_with_federated_mean(self):
  """Checks that a `federated_mean` intrinsic fails the whitelist check.

  `tree_analysis.check_intrinsics_whitelisted_for_reduction` is expected to
  raise `ValueError` with a message containing the intrinsic's compact
  representation.
  """
  import re  # Local import: only this test needs regex escaping.
  intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MEAN.uri,
      computation_types.FunctionType(
          computation_types.FederatedType(tf.int32, placements.CLIENTS),
          computation_types.FederatedType(tf.int32, placements.SERVER)))
  # `assertRaisesRegex` treats its argument as a regex. Escape the rendered
  # intrinsic so any metacharacters in it are matched literally.
  expected_message = re.escape(
      computation_building_blocks.compact_representation(intrinsic))
  with self.assertRaisesRegex(ValueError, expected_message):
    tree_analysis.check_intrinsics_whitelisted_for_reduction(intrinsic)
def test_returns_string_for_compiled_computation(self):
  """A named compiled computation renders as `comp#name` / `Compiled(name)`."""
  constant_proto, _ = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          lambda: tf.constant(1), None, context_stack_impl.context_stack))
  compiled = computation_building_blocks.CompiledComputation(
      constant_proto, 'a')
  self.assertEqual(
      computation_building_blocks.compact_representation(compiled), 'comp#a')
  self.assertEqual(
      computation_building_blocks.formatted_representation(compiled),
      'comp#a')
  self.assertEqual(
      computation_building_blocks.structural_representation(compiled),
      'Compiled(a)')
def test_basic_functionality_of_reference_class(self):
  """Checks name, type, repr and proto form of `Reference`."""
  ref = computation_building_blocks.Reference('foo', tf.int32)
  self.assertEqual(ref.name, 'foo')
  self.assertEqual(str(ref.type_signature), 'int32')
  self.assertEqual(repr(ref), 'Reference(\'foo\', TensorType(tf.int32))')
  self.assertEqual(
      computation_building_blocks.compact_representation(ref), 'foo')
  ref_proto = ref.proto
  self.assertEqual(
      type_serialization.deserialize_type(ref_proto.type),
      ref.type_signature)
  self.assertEqual(ref_proto.WhichOneof('computation'), 'reference')
  self.assertEqual(ref_proto.reference.name, ref.name)
  self._serialize_deserialize_roundtrip_test(ref)
def __eq__(self, other):
  """Compares two trackers by name and bound value.

  Trackers with different `name`s are unequal. When both `value`s are
  `ComputationBuildingBlock`s, they are equal when their compact
  representations match and their type signatures are equivalent. Any other
  pair of values is equal only if it is the same object.

  Args:
    other: The object to compare against.

  Returns:
    A bool, or `NotImplemented` when `other` is not a `BoundVariableTracker`.
  """
  # TODO(b/130890785): Delegate value-checking to
  # `computation_building_blocks.ComputationBuildingBlock`.
  if self is other:
    return True
  if not isinstance(other, BoundVariableTracker):
    return NotImplemented
  if self.name != other.name:
    return False
  building_block_cls = computation_building_blocks.ComputationBuildingBlock
  mine, theirs = self.value, other.value
  both_building_blocks = (
      isinstance(mine, building_block_cls) and
      isinstance(theirs, building_block_cls))
  if not both_building_blocks:
    # Values that are not both building blocks are compared by identity.
    return mine is theirs
  same_rendering = (
      computation_building_blocks.compact_representation(mine) ==
      computation_building_blocks.compact_representation(theirs))
  return same_rendering and type_utils.are_equivalent_types(
      mine.type_signature, theirs.type_signature)
def test_basic_functionality_of_placement_class(self):
  """Checks type, URI, repr and proto form of `Placement`."""
  clients = computation_building_blocks.Placement(placements.CLIENTS)
  self.assertEqual(str(clients.type_signature), 'placement')
  self.assertEqual(clients.uri, 'clients')
  self.assertEqual(repr(clients), 'Placement(\'clients\')')
  self.assertEqual(
      computation_building_blocks.compact_representation(clients), 'CLIENTS')
  placement_proto = clients.proto
  self.assertEqual(
      type_serialization.deserialize_type(placement_proto.type),
      clients.type_signature)
  self.assertEqual(placement_proto.WhichOneof('computation'), 'placement')
  self.assertEqual(placement_proto.placement.uri, clients.uri)
  self._serialize_deserialize_roundtrip_test(clients)
def test_basic_functionality_of_lambda_class(self):
  """Checks type, parameter, result, repr and proto form of `Lambda`."""
  param_name = 'arg'
  param_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                ('x', tf.int32)]
  param_ref = computation_building_blocks.Reference(param_name, param_type)
  select_f = computation_building_blocks.Selection(param_ref, name='f')
  select_x = computation_building_blocks.Selection(param_ref, name='x')
  # The body applies `f` twice: f(f(x)).
  inner_call = computation_building_blocks.Call(select_f, select_x)
  body = computation_building_blocks.Call(select_f, inner_call)
  lam = computation_building_blocks.Lambda(param_name, param_type, body)
  self.assertEqual(str(lam.type_signature),
                   '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(lam.parameter_name, param_name)
  self.assertEqual(str(lam.parameter_type), '<f=(int32 -> int32),x=int32>')
  self.assertEqual(
      computation_building_blocks.compact_representation(lam.result),
      'arg.f(arg.f(arg.x))')
  param_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  expected_repr = (
      'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))').format(
          param_type_repr)
  self.assertEqual(repr(lam), expected_repr)
  self.assertEqual(computation_building_blocks.compact_representation(lam),
                   '(arg -> arg.f(arg.f(arg.x)))')
  lambda_proto = lam.proto
  self.assertEqual(
      type_serialization.deserialize_type(lambda_proto.type),
      lam.type_signature)
  self.assertEqual(lambda_proto.WhichOneof('computation'), 'lambda')
  # `lambda` is a Python keyword, so the proto field is read via getattr.
  lambda_field = getattr(lambda_proto, 'lambda')
  self.assertEqual(lambda_field.parameter_name, param_name)
  self.assertEqual(str(lambda_field.result), str(lam.result.proto))
  self._serialize_deserialize_roundtrip_test(lam)
def test_basic_intrinsic_functionality_plus_canonical_typecheck(self):
  """Checks a binary `generic_plus` intrinsic and its proto form."""
  plus_type = computation_types.FunctionType([tf.int32, tf.int32], tf.int32)
  generic_plus = computation_building_blocks.Intrinsic('generic_plus',
                                                       plus_type)
  self.assertEqual(str(generic_plus.type_signature),
                   '(<int32,int32> -> int32)')
  self.assertEqual(generic_plus.uri, 'generic_plus')
  self.assertEqual(
      computation_building_blocks.compact_representation(generic_plus),
      'generic_plus')
  plus_proto = generic_plus.proto
  self.assertEqual(
      type_serialization.deserialize_type(plus_proto.type),
      generic_plus.type_signature)
  self.assertEqual(plus_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(plus_proto.intrinsic.uri, generic_plus.uri)
  self._serialize_deserialize_roundtrip_test(generic_plus)
def test_returns_string_for_selection_with_index(self):
  """An index selection renders as `a[0]`, structurally as `Sel(0)` over a ref."""
  source_ref = computation_building_blocks.Reference(
      'a', (('b', tf.int32), ('c', tf.bool)))
  first_element = computation_building_blocks.Selection(source_ref, index=0)
  self.assertEqual(
      computation_building_blocks.compact_representation(first_element),
      'a[0]')
  self.assertEqual(
      computation_building_blocks.formatted_representation(first_element),
      'a[0]')
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(first_element),
      'Sel(0)\n'
      '|\n'
      'Ref(a)')
  # pyformat: enable
def test_returns_string_for_lambda(self):
  """An identity lambda renders as `(a -> a)`, structurally as a two-level tree."""
  param_ref = computation_building_blocks.Reference('a', tf.int32)
  identity_fn = computation_building_blocks.Lambda(
      param_ref.name, param_ref.type_signature, param_ref)
  self.assertEqual(
      computation_building_blocks.compact_representation(identity_fn),
      '(a -> a)')
  self.assertEqual(
      computation_building_blocks.formatted_representation(identity_fn),
      '(a -> a)')
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(identity_fn),
      'Lambda(a)\n'
      '|\n'
      'Ref(a)')
  # pyformat: enable
def test_returns_string_for_tuple_with_no_names(self):
  """Checks all three renderings of an unnamed two-element tuple."""
  data = computation_building_blocks.Data('data', tf.int32)
  pair = computation_building_blocks.Tuple((data, data))
  self.assertEqual(
      computation_building_blocks.compact_representation(pair),
      '<data,data>')
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(pair),
      '<\n'
      ' data,\n'
      ' data\n'
      '>')
  # pyformat: enable
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(pair),
      'Tuple\n'
      '|\n'
      '[data, data]')
  # pyformat: enable
def test_basic_functionality_of_block_class(self):
  """Checks locals, result, repr and proto form of `Block`."""
  x = computation_building_blocks.Block([
      ('x',
       computation_building_blocks.Reference('arg', (tf.int32, tf.int32))),
      ('y',
       computation_building_blocks.Selection(
           computation_building_blocks.Reference('x', (tf.int32, tf.int32)),
           index=0))
  ], computation_building_blocks.Reference('y', tf.int32))
  self.assertEqual(str(x.type_signature), 'int32')
  self.assertEqual(
      [(k, computation_building_blocks.compact_representation(v))
       for k, v in x.locals],
      [('x', 'arg'), ('y', 'x[0]')])
  self.assertEqual(
      computation_building_blocks.compact_representation(x.result), 'y')
  self.assertEqual(
      repr(x), 'Block([(\'x\', Reference(\'arg\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
      '(\'y\', Selection(Reference(\'x\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
      'index=0))], '
      'Reference(\'y\', TensorType(tf.int32)))')
  self.assertEqual(computation_building_blocks.compact_representation(x),
                   '(let x=arg,y=x[0] in y)')
  x_proto = x.proto
  self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                   x.type_signature)
  self.assertEqual(x_proto.WhichOneof('computation'), 'block')
  self.assertEqual(str(x_proto.block.result), str(x.result.proto))
  # The serialized locals must pair up one-to-one with the in-memory locals.
  # Check the counts first, since `zip` would silently stop at the shorter
  # sequence.
  self.assertLen(x_proto.block.local, len(x.locals))
  for (loc_name, loc_value), loc_proto in zip(x.locals, x_proto.block.local):
    self.assertEqual(loc_proto.name, loc_name)
    self.assertEqual(str(loc_proto.value), str(loc_value.proto))
  self._serialize_deserialize_roundtrip_test(x)
def test_basic_functionality_of_intrinsic_class(self):
  """Checks type, URI, repr and proto form of a unary `Intrinsic`."""
  add_one_type = computation_types.FunctionType(tf.int32, tf.int32)
  add_one = computation_building_blocks.Intrinsic('add_one', add_one_type)
  self.assertEqual(str(add_one.type_signature), '(int32 -> int32)')
  self.assertEqual(add_one.uri, 'add_one')
  expected_repr = (
      'Intrinsic(\'add_one\', '
      'FunctionType(TensorType(tf.int32), TensorType(tf.int32)))')
  self.assertEqual(repr(add_one), expected_repr)
  self.assertEqual(
      computation_building_blocks.compact_representation(add_one), 'add_one')
  intrinsic_proto = add_one.proto
  self.assertEqual(
      type_serialization.deserialize_type(intrinsic_proto.type),
      add_one.type_signature)
  self.assertEqual(intrinsic_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(intrinsic_proto.intrinsic.uri, add_one.uri)
  self._serialize_deserialize_roundtrip_test(add_one)
def test_basic_functionality_of_data_class(self):
  """Checks type, URI, repr and proto form of `Data`."""
  sequence_type = computation_types.SequenceType(tf.int32)
  data = computation_building_blocks.Data('/tmp/mydata', sequence_type)
  self.assertEqual(str(data.type_signature), 'int32*')
  self.assertEqual(data.uri, '/tmp/mydata')
  self.assertEqual(
      repr(data), 'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
  self.assertEqual(
      computation_building_blocks.compact_representation(data),
      '/tmp/mydata')
  data_proto = data.proto
  self.assertEqual(
      type_serialization.deserialize_type(data_proto.type),
      data.type_signature)
  self.assertEqual(data_proto.WhichOneof('computation'), 'data')
  self.assertEqual(data_proto.data.uri, data.uri)
  self._serialize_deserialize_roundtrip_test(data)
def test_get_curried(self):
  """Currying a binary `tf.add` yields nested single-argument lambdas."""
  add_comp = computations.tf_computation(tf.add, [tf.int32, tf.int32])
  add_proto = computation_impl.ComputationImpl.get_proto(add_comp)
  add_block = computation_building_blocks.ComputationBuildingBlock.from_proto(
      add_proto)
  add_numbers = value_impl.ValueImpl(add_block, _context_stack)
  curried = value_utils.get_curried(add_numbers)
  self.assertEqual(
      str(curried.type_signature), '(int32 -> (int32 -> int32))')
  # Presumably renames compiled computations to stable names, so the
  # expected string below can refer to `comp#1`.
  renamed_comp, _ = transformations.uniquify_compiled_computation_names(
      value_impl.ValueImpl.get_comp(curried))
  self.assertEqual(
      computation_building_blocks.compact_representation(renamed_comp),
      '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
def test_returns_string_for_call_with_arg(self):
  """Checks all three renderings of an identity lambda applied to data."""
  param_ref = computation_building_blocks.Reference('a', tf.int32)
  identity_fn = computation_building_blocks.Lambda(
      param_ref.name, param_ref.type_signature, param_ref)
  data = computation_building_blocks.Data('data', tf.int32)
  call = computation_building_blocks.Call(identity_fn, data)
  self.assertEqual(
      computation_building_blocks.compact_representation(call),
      '(a -> a)(data)')
  self.assertEqual(
      computation_building_blocks.formatted_representation(call),
      '(a -> a)(data)')
  # NOTE(review): the tree edges in this expected string do not line up;
  # runs of spaces may have been collapsed to single spaces. Confirm against
  # the actual `structural_representation` output.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(call),
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Ref(a)')
  # pyformat: enable