def test_multiple_named_and_unnamed(self):
  """Checks a tuple mixing one unnamed element with two named elements."""
  v = [(None, 10), ('foo', 20), ('bar', 30)]
  x = anonymous_tuple.AnonymousTuple(v)
  self.assertLen(x, 3)
  self.assertEqual(x[0], 10)
  self.assertEqual(x[1], 20)
  self.assertEqual(x[2], 30)
  with self.assertRaises(IndexError):
    _ = x[3]
  self.assertEqual(list(iter(x)), [10, 20, 30])
  # Only named elements show up in dir() and as attributes.
  self.assertEqual(dir(x), ['bar', 'foo'])
  self.assertEqual(x.foo, 20)
  self.assertEqual(x.bar, 30)
  with self.assertRaises(AttributeError):
    _ = x.baz
  self.assertEqual(
      x, anonymous_tuple.AnonymousTuple([(None, 10), ('foo', 20), ('bar', 30)]))
  # Same values, but the names are attached to different positions.
  self.assertNotEqual(
      x, anonymous_tuple.AnonymousTuple([('foo', 10), ('bar', 20), (None, 30)]))
  self.assertEqual(anonymous_tuple.to_elements(x), v)
  # Element names are rendered with repr(), so string names appear quoted;
  # this matches the single-named-element repr expectation.
  self.assertEqual(
      repr(x), "AnonymousTuple([(None, 10), ('foo', 20), ('bar', 30)])")
  self.assertEqual(str(x), '<10,foo=20,bar=30>')
  # A tuple with any unnamed element cannot be converted to an OrderedDict.
  with self.assertRaisesRegex(ValueError, 'unnamed'):
    anonymous_tuple.to_odict(x)
def _server_state_from_tff_result(self, result):
  """Rebuilds a `ServerState` from the structure returned by a TFF computation.

  Args:
    result: Structure exposing `model` (with `trainable` and `non_trainable`),
      `optimizer_state`, `delta_aggregate_state` and `model_broadcast_state`.

  Returns:
    A `tff.learning.framework.ServerState`. When per-vector clipping is
    enabled, the aggregate state is a list with one OrderedDict per element of
    `delta_aggregate_state`; otherwise it is a single OrderedDict.
  """
  if not self._per_vector_clipping:
    aggregate_states = anonymous_tuple.to_odict(
        result.delta_aggregate_state, recursive=True)
  else:
    # Each element holds the aggregation state for one vector; convert them
    # individually and keep them in order.
    aggregate_states = [
        anonymous_tuple.to_odict(state, recursive=True)
        for _, state in anonymous_tuple.iter_elements(
            result.delta_aggregate_state)
    ]
  model = result.model
  weights = tff.learning.ModelWeights(
      tuple(model.trainable), tuple(model.non_trainable))
  return tff.learning.framework.ServerState(
      weights,
      list(result.optimizer_state),
      aggregate_states,
      tuple(result.model_broadcast_state))
def _get_accumulator_type(member_type):
  """Builds the `tff.Type` of the accumulator used in sample aggregation.

  Args:
    member_type: A `tff.Type` describing the member constituents of the
      federated type being sampled.

  Returns:
    A `tff.NamedTupleType` with two entries, `accumulators` and `rands`. They
    are parallel lists: index i of one corresponds to index i of the other.
    Together they allow sampling from the accumulators with equal
    probability.
  """
  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if not isinstance(member_type, tff.NamedTupleType):
    accumulators_type = tff.TensorType(
        member_type.dtype, shape=[None] + member_type.shape.dims)
  else:
    # Prepend an unknown-length leading dimension to every leaf tensor type.
    batched = anonymous_tuple.map_structure(
        lambda leaf: tff.TensorType(leaf.dtype, [None] + leaf.shape.dims),
        member_type)
    accumulators_type = tff.NamedTupleType(
        anonymous_tuple.to_odict(batched, True))
  return tff.NamedTupleType(
      collections.OrderedDict([
          ('accumulators', accumulators_type),
          ('rands', tff.TensorType(tf.float32, shape=[None])),
      ]))
def _server_state_from_tff_result(self, result):
  """Rebuilds a `ServerState` from the structure returned by a TFF computation.

  Args:
    result: Structure exposing `model` (with `trainable` and `non_trainable`),
      `optimizer_state`, `delta_aggregate_state` and `model_broadcast_state`.

  Returns:
    A `tff.learning.framework.ServerState` whose aggregate state is the
    recursive OrderedDict form of `result.delta_aggregate_state`.
  """
  model = result.model
  weights = tff.learning.ModelWeights(
      tuple(model.trainable), tuple(model.non_trainable))
  return tff.learning.framework.ServerState(
      weights,
      list(result.optimizer_state),
      anonymous_tuple.to_odict(result.delta_aggregate_state, True),
      tuple(result.model_broadcast_state))
def accumlator_type_fn():
  """Builds a zero-sample accumulator value for `member_type`.

  `member_type` is read from the enclosing scope. Every tensor gets a leading
  dimension of size 0.

  Returns:
    A `_Samples` built from the zero-length accumulator(s) and an empty
    float32 tensor of random values.

  Raises:
    ValueError: If `member_type` is a tensor type whose `shape` is falsy
      (for a `tf.TensorShape`, this means the rank is unknown). No zero-length
      tensor can be materialized for it.
  """
  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if isinstance(member_type, tff.NamedTupleType):
    a = anonymous_tuple.map_structure(
        lambda v: tf.zeros([0] + v.shape.dims, v.dtype), member_type)
    # Convert recursively so nested named tuples also become OrderedDicts.
    # This mirrors the recursive `to_odict(a, True)` used when building the
    # accumulator type.
    return _Samples(
        anonymous_tuple.to_odict(a, True), tf.zeros([0], tf.float32))
  if not member_type.shape:
    raise ValueError(
        'Cannot build an empty accumulator for a tensor of unknown rank: '
        '{}'.format(member_type))
  s = [0] + member_type.shape.dims
  return _Samples(tf.zeros(s, member_type.dtype), tf.zeros([0], tf.float32))
def test_single_unnamed(self):
  """Checks a tuple holding exactly one element that has no name."""
  elements = [(None, 10)]
  t = anonymous_tuple.AnonymousTuple(elements)
  self.assertLen(t, 1)
  with self.assertRaises(IndexError):
    _ = t[1]
  self.assertEqual(t[0], 10)
  self.assertEqual(list(iter(t)), [10])
  # Without names there is nothing to list or access as an attribute.
  self.assertEqual(dir(t), [])
  with self.assertRaises(AttributeError):
    _ = t.foo
  # Differs in length, then differs only in naming.
  for other in ([], [('foo', 10)]):
    self.assertNotEqual(t, anonymous_tuple.AnonymousTuple(other))
  self.assertEqual(t, anonymous_tuple.AnonymousTuple([(None, 10)]))
  self.assertNotEqual(
      t, anonymous_tuple.AnonymousTuple([(None, 10), ('foo', 20)]))
  self.assertEqual(anonymous_tuple.to_elements(t), elements)
  self.assertEqual(repr(t), 'AnonymousTuple([(None, 10)])')
  self.assertEqual(str(t), '<10>')
  with self.assertRaisesRegex(ValueError, 'unnamed'):
    anonymous_tuple.to_odict(t)
def accumlator_type_fn():
  """Gets the type for the accumulators.

  Builds a zero-sample accumulator value for `member_type`, which is read from
  the enclosing scope. Every tensor gets a leading dimension of size 0.

  Returns:
    A `_Samples` built from the zero-length accumulator(s) and an empty
    float32 tensor of random values.

  Raises:
    ValueError: If `member_type` is not a struct and its `shape` is falsy
      (for a `tf.TensorShape`, this means the rank is unknown). No zero-length
      tensor can be materialized for it.
  """
  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    a = anonymous_tuple.map_structure(
        lambda v: tf.zeros([0] + v.shape.dims, v.dtype), member_type)
    return _Samples(
        anonymous_tuple.to_odict(a, True), tf.zeros([0], tf.float32))
  if not member_type.shape:
    raise ValueError(
        'Cannot build an empty accumulator for a tensor of unknown rank: '
        '{}'.format(member_type))
  s = [0] + member_type.shape.dims
  return _Samples(tf.zeros(s, member_type.dtype), tf.zeros([0], tf.float32))
def test_federated_sum_named_tuples(self):
  """Checks the FEDERATED_SUM intrinsic body on a named-tuple member type."""
  bodies = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  client_type = computation_types.FederatedType(
      [('a', tf.int32), ('b', tf.float32)], placements.CLIENTS)

  @computations.federated_computation(client_type)
  def federated_sum(x):
    return bodies[intrinsic_defs.FEDERATED_SUM.uri](x)

  self.assertEqual(
      str(federated_sum.type_signature),
      '({<a=int32,b=float32>}@CLIENTS -> <a=int32,b=float32>@SERVER)')
  # Each case: the client values, then the expected field-wise sum.
  cases = (
      ([[1, 2.]], {'a': 1, 'b': 2.}),
      ([[1, 2.], [3, 4.]], {'a': 4, 'b': 6.}),
  )
  for client_values, expected in cases:
    self.assertDictEqual(
        anonymous_tuple.to_odict(federated_sum(client_values)), expected)
def test_empty(self):
  """Checks an AnonymousTuple with no elements at all."""
  elements = []
  t = anonymous_tuple.AnonymousTuple(elements)
  # This test exercises __len__() explicitly, so it uses assertLen() rather
  # than assertEmpty().
  self.assertLen(t, 0)  # pylint: disable=g-generic-assert
  with self.assertRaises(IndexError):
    _ = t[0]
  self.assertEqual(list(iter(t)), [])
  self.assertEqual(dir(t), [])
  with self.assertRaises(AttributeError):
    _ = t.foo
  self.assertEqual(t, anonymous_tuple.AnonymousTuple([]))
  self.assertNotEqual(t, anonymous_tuple.AnonymousTuple([('foo', 10)]))
  self.assertEqual(anonymous_tuple.to_elements(t), elements)
  self.assertEqual(anonymous_tuple.to_odict(t), collections.OrderedDict())
  self.assertEqual(repr(t), 'AnonymousTuple([])')
  self.assertEqual(str(t), '<>')
def server_update_model_tf(server_state, model_delta):
  """Converts TFF values to the Python types `server_update_model` expects.

  Args:
    server_state: An `AnonymousTuple` exposing `model` and `optimizer_state`.
    model_delta: An `AnonymousTuple` holding the model update; it is converted
      (non-recursively) to an OrderedDict.

  Returns:
    The result of calling `server_update_model` with the converted values.

  Raises:
    Whatever `py_typecheck.check_type` raises (presumably `TypeError`) when
    either argument is not an `AnonymousTuple`.
  """
  # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
  # pretty, fold this into anonymous_tuple module or get working with
  # tf.contrib.framework.nest.
  py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
  delta_odict = anonymous_tuple.to_odict(model_delta)
  py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
  python_state = ServerState(
      model=model_utils.ModelWeights.from_tff_value(server_state.model),
      optimizer_state=list(server_state.optimizer_state))
  return server_update_model(
      python_state,
      delta_odict,
      model_fn=model_fn,
      optimizer_fn=server_optimizer_fn)
def test_single_named(self):
  """Checks a tuple holding exactly one named element."""
  elements = [('foo', 20)]
  t = anonymous_tuple.AnonymousTuple(elements)
  self.assertLen(t, 1)
  self.assertEqual(t[0], 20)
  with self.assertRaises(IndexError):
    _ = t[1]
  self.assertEqual(list(iter(t)), [20])
  self.assertEqual(dir(t), ['foo'])
  self.assertEqual(t.foo, 20)
  with self.assertRaises(AttributeError):
    _ = t.bar
  # Differs in length, in value, and in naming, respectively.
  for other in ([], [('foo', 10)], [(None, 20)]):
    self.assertNotEqual(t, anonymous_tuple.AnonymousTuple(other))
  self.assertEqual(t, anonymous_tuple.AnonymousTuple([('foo', 20)]))
  self.assertNotEqual(
      t, anonymous_tuple.AnonymousTuple([('foo', 20), ('bar', 30)]))
  self.assertEqual(anonymous_tuple.to_elements(t), elements)
  self.assertEqual(repr(t), 'AnonymousTuple([(\'foo\', 20)])')
  self.assertEqual(str(t), '<foo=20>')
  self.assertEqual(
      anonymous_tuple.to_odict(t), collections.OrderedDict(elements))
def from_tff_value(cls, anon_tuple):
  """Builds an instance of `cls` from an `AnonymousTuple` of model weights.

  Args:
    anon_tuple: An `AnonymousTuple` exposing `trainable` and `non_trainable`.
      Each one is converted (non-recursively) to an OrderedDict.

  Returns:
    `cls(trainable, non_trainable)` built from the converted OrderedDicts.
  """
  py_typecheck.check_type(anon_tuple, anonymous_tuple.AnonymousTuple)
  trainable = anonymous_tuple.to_odict(anon_tuple.trainable)
  non_trainable = anonymous_tuple.to_odict(anon_tuple.non_trainable)
  return cls(trainable, non_trainable)