def test_multiple_named_and_unnamed(self):
  """Exercises a Struct that mixes one unnamed element with two named ones."""
  elements = [(None, 10), ('foo', 20), ('bar', 30)]
  mixed = structure.Struct(elements)

  # Positional access and iteration see all three elements in order.
  self.assertLen(mixed, 3)
  for position, expected_value in enumerate([10, 20, 30]):
    self.assertEqual(mixed[position], expected_value)
  with self.assertRaises(IndexError):
    _ = mixed[3]
  self.assertEqual(list(iter(mixed)), [10, 20, 30])

  # Only named elements are attributes. dir() lists them alphabetically,
  # while name_list() keeps declaration order.
  self.assertEqual(dir(mixed), ['bar', 'foo'])
  self.assertEqual(structure.name_list(mixed), ['foo', 'bar'])
  self.assertEqual(mixed.foo, 20)
  self.assertEqual(mixed.bar, 30)
  with self.assertRaises(AttributeError):
    _ = mixed.baz

  # The same values with differently placed names must not compare equal.
  self.assertEqual(mixed,
                   structure.Struct([(None, 10), ('foo', 20), ('bar', 30)]))
  self.assertNotEqual(
      mixed, structure.Struct([('foo', 10), ('bar', 20), (None, 30)]))

  # Round-tripping and textual forms.
  self.assertEqual(structure.to_elements(mixed), elements)
  self.assertEqual(
      repr(mixed), 'Struct([(None, 10), (\'foo\', 20), (\'bar\', 30)])')
  self.assertEqual(str(mixed), '<10,foo=20,bar=30>')

  # The unnamed element has no key, so both dict-style conversions reject it.
  with self.assertRaisesRegex(ValueError, 'unnamed'):
    structure.to_odict(mixed)
  with self.assertRaisesRegex(ValueError, 'named and unnamed'):
    structure.to_odict_or_tuple(mixed)
def _get_accumulator_type(member_type):
  """Returns the `tff.StructType` of the sample-aggregation accumulator.

  The accumulator is a struct with two parallel fields, `accumulators` and
  `rands`. The i-th entry of one corresponds to the i-th entry of the other,
  and together they are used to sample from the accumulators with equal
  probability.

  Args:
    member_type: A `tff.Type` describing the member components of the
      federated type being sampled.

  Returns:
    A `tff.StructType` with two fields:
      * `accumulators`: `member_type` with an extra leading dimension of
        unknown size prepended to every tensor.
      * `rands`: a float32 vector of unknown length.
  """
  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    batched_leaves = structure.map_structure(
        lambda leaf: computation_types.TensorType(
            leaf.dtype, [None] + leaf.shape.dims),
        member_type)
    accumulators_type = computation_types.StructType(
        structure.to_odict(batched_leaves, True))
  else:
    accumulators_type = computation_types.TensorType(
        member_type.dtype, shape=[None] + member_type.shape.dims)

  rands_type = computation_types.TensorType(tf.float32, shape=[None])
  return computation_types.StructType(
      collections.OrderedDict([
          ('accumulators', accumulators_type),
          ('rands', rands_type),
      ]))
def test_single_unnamed(self):
  """Checks a Struct holding exactly one unnamed element."""
  elements = [(None, 10)]
  single = structure.Struct(elements)

  self.assertLen(single, 1)
  with self.assertRaises(IndexError):
    _ = single[1]
  self.assertEqual(single[0], 10)
  self.assertEqual(list(iter(single)), [10])

  # With no names, nothing is exposed as an attribute.
  self.assertEqual(dir(single), [])
  with self.assertRaises(AttributeError):
    _ = single.foo

  # A different length, or a name where there was none, breaks equality.
  for different in ([], [('foo', 10)]):
    self.assertNotEqual(single, structure.Struct(different))
  self.assertEqual(single, structure.Struct([(None, 10)]))
  self.assertNotEqual(single, structure.Struct([(None, 10), ('foo', 20)]))

  self.assertEqual(structure.to_elements(single), elements)
  self.assertEqual(repr(single), 'Struct([(None, 10)])')
  self.assertEqual(str(single), '<10>')

  # An unnamed element has no key, so dict conversion is rejected.
  with self.assertRaisesRegex(ValueError, 'unnamed'):
    structure.to_odict(single)
def accumlator_type_fn():
  """Builds empty (zero-length) sample accumulators for `member_type`.

  Despite its name, this returns values rather than a type. Every tensor in
  `member_type` becomes a zero-filled tensor with an extra leading dimension
  of size 0. Those tensors are paired in a `_Samples` with an empty float32
  vector (presumably the `rands` field; verify against `_Samples`).

  Reads `member_type` and `_Samples` from the enclosing scope.

  Returns:
    A `_Samples` built from the empty accumulators and an empty float32
    vector.

  Raises:
    ValueError: If `member_type`, or any tensor inside it, has unknown rank.
      No leading sample dimension can be prepended to such a shape. The
      previous code either fell through without a value or failed with an
      unrelated error in that case.
  """

  def _empty_with_sample_dim(tensor_type):
    # Zero-length tensor: `tensor_type`'s shape with a leading dimension of 0.
    dims = tensor_type.shape.dims
    if dims is None:
      raise ValueError(
          'Cannot build sample accumulators for a tensor of unknown rank: '
          '{}'.format(tensor_type))
    return tf.zeros([0] + dims, tensor_type.dtype)

  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    a = structure.map_structure(_empty_with_sample_dim, member_type)
    accumulators = structure.to_odict(a, True)
  else:
    accumulators = _empty_with_sample_dim(member_type)
  return _Samples(accumulators, tf.zeros([0], tf.float32))
def test_federated_sum_named_tuples(self):
  """Checks the FEDERATED_SUM intrinsic body on a named-tuple client value."""
  bodies = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  clients_type = computation_types.FederatedType(
      [('a', tf.int32), ('b', tf.float32)], placement_literals.CLIENTS)

  @computations.federated_computation(clients_type)
  def foo(x):
    return bodies[intrinsic_defs.FEDERATED_SUM.uri](x)

  self.assertEqual(
      str(foo.type_signature),
      '({<a=int32,b=float32>}@CLIENTS -> <a=int32,b=float32>@SERVER)')

  # A single client's value passes through unchanged; with two clients each
  # named field is summed independently.
  cases = [
      ([[1, 2.]], {'a': 1, 'b': 2.}),
      ([[1, 2.], [3, 4.]], {'a': 4, 'b': 6.}),
  ]
  for client_values, expected in cases:
    self.assertDictEqual(structure.to_odict(foo(client_values)), expected)
def test_empty(self):
  """Checks that a Struct with no elements behaves like an empty container."""
  elements = []
  empty = structure.Struct(elements)

  # __len__() itself is under test here, so assertLen() is used deliberately
  # instead of assertEmpty().
  self.assertLen(empty, 0)  # pylint: disable=g-generic-assert
  with self.assertRaises(IndexError):
    _ = empty[0]
  self.assertEqual(list(iter(empty)), [])

  # No elements means no attributes.
  self.assertEqual(dir(empty), [])
  with self.assertRaises(AttributeError):
    _ = empty.foo

  self.assertEqual(empty, structure.Struct([]))
  self.assertNotEqual(empty, structure.Struct([('foo', 10)]))

  # Conversions and textual forms of the empty struct.
  self.assertEqual(structure.to_elements(empty), elements)
  self.assertEqual(structure.to_odict(empty), collections.OrderedDict())
  self.assertEqual(repr(empty), 'Struct([])')
  self.assertEqual(str(empty), '<>')
def test_single_named(self):
  """Checks a Struct holding exactly one named element."""
  elements = [('foo', 20)]
  named = structure.Struct(elements)

  self.assertLen(named, 1)
  self.assertEqual(named[0], 20)
  with self.assertRaises(IndexError):
    _ = named[1]
  self.assertEqual(list(iter(named)), [20])

  # The one name is reachable as an attribute; any other name is not.
  self.assertEqual(dir(named), ['foo'])
  self.assertEqual(named.foo, 20)
  with self.assertRaises(AttributeError):
    _ = named.bar

  # A different length, a different value, or a missing name breaks equality.
  for different in ([], [('foo', 10)], [(None, 20)]):
    self.assertNotEqual(named, structure.Struct(different))
  self.assertEqual(named, structure.Struct([('foo', 20)]))
  self.assertNotEqual(named, structure.Struct([('foo', 20), ('bar', 30)]))

  # Conversions and textual forms; a fully named struct converts to a dict.
  self.assertEqual(structure.to_elements(named), elements)
  self.assertEqual(repr(named), 'Struct([(\'foo\', 20)])')
  self.assertEqual(str(named), '<foo=20>')
  self.assertEqual(structure.to_odict(named), collections.OrderedDict(elements))