def test_update_struct_namedtuple(self):
    """`update_struct` on a namedtuple yields a tuple with the given fields replaced.

    Updates are chained: first `c`, then `a` on the already-updated value.
    """
    triple_type = collections.namedtuple('my_tuple_type', 'a b c')
    original = triple_type(a=1, b=2, c=3)

    with_new_c = structure.update_struct(original, c=7)
    self.assertEqual(with_new_c, triple_type(a=1, b=2, c=7))

    with_new_a_and_c = structure.update_struct(with_new_c, a=8)
    self.assertEqual(with_new_a_and_c, triple_type(a=8, b=2, c=7))
def test_update_struct_ordereddict(self):
    """`update_struct` on an OrderedDict replaces only the named keys' values.

    Updates are chained: first `c`, then `a` on the already-updated mapping.
    """
    initial = collections.OrderedDict(a=1, b=2, c=3)

    after_c = structure.update_struct(initial, c=7)
    self.assertEqual(after_c, collections.OrderedDict(a=1, b=2, c=7))

    after_a = structure.update_struct(after_c, a=8)
    self.assertEqual(after_a, collections.OrderedDict(a=8, b=2, c=7))
def test_update_struct_attrs(self):
    """`update_struct` on an attrs instance yields one with the given attributes replaced.

    Updates are chained: first `c`, then `a` on the already-updated instance.
    """

    @attr.s
    class TestAttrsClass(object):
        a = attr.ib()
        b = attr.ib()
        c = attr.ib()

    first = TestAttrsClass(a=1, b=2, c=3)

    second = structure.update_struct(first, c=7)
    self.assertEqual(second, TestAttrsClass(a=1, b=2, c=7))

    third = structure.update_struct(second, a=8)
    self.assertEqual(third, TestAttrsClass(a=8, b=2, c=7))
def test_update_struct_fails(self):
    """`update_struct` rejects unsupported inputs and unknown field names."""
    # A plain tuple and a plain list are both expected to raise TypeError.
    for rejected_input in ((1, 2, 3), [1, 2, 3]):
        with self.assertRaisesRegex(
                TypeError, '`structure` must be a structure'):
            structure.update_struct(rejected_input, a=8)

    # A dict is accepted, but 'a' is not one of its keys, so KeyError.
    with self.assertRaisesRegex(KeyError, 'does not contain a field'):
        structure.update_struct({'z': 1}, a=8)
def server_update(server_state, weights_delta, aggregator_state,
                  broadcaster_state):
    """Updates the `server_state` based on `weights_delta`.

    A model is built with `model_fn` and its global variables are loaded from
    `server_state.model`. The negated `weights_delta` is then passed to the
    server optimizer's `next` step. If `weights_delta` contains a non-finite
    value, the optimizer step is skipped: the returned model equals
    `server_state.model` and the returned optimizer state is
    `server_state.optimizer_state`.

    `model_fn` and `server_optimizer_fn` are free variables of this function.
    They are presumably captured from an enclosing builder function; verify
    this against the caller.

    Args:
      server_state: A `tff.learning.framework.ServerState`, the state to be
        updated.
      weights_delta: The model delta in global trainable variables from
        clients.
      aggregator_state: The state of the aggregator after performing
        aggregation.
      broadcaster_state: The state of the broadcaster after broadcasting.

    Returns:
      The updated `tff.learning.framework.ServerState`. Its `model`,
      `optimizer_state`, `model_broadcast_state` and `delta_aggregate_state`
      fields are replaced; all other fields are carried over by
      `structure.update_struct`.
    """
    # NOTE(review): SOURCE was whitespace-collapsed, so the extent of this
    # `with` block is reconstructed here. Confirm which statements belong
    # under `tf.init_scope()` against version control.
    with tf.init_scope():
        model = model_fn()
        global_model_weights = reconstruction_utils.get_global_variables(model)
        optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            server_optimizer_fn,
            global_model_weights.trainable,
            disjoint_init_and_next=True)
    optimizer_state = server_state.optimizer_state

    # Initialize the model with the current state.
    tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                          server_state.model)

    # Per the helper's name, the delta is zeroed if any entry is non-finite.
    # `has_non_finite_weight` is the flag tested below.
    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))

    # We ignore the update if the weights_delta is non finite.
    # NOTE(review): this `if` branches on a tensor. Presumably the function is
    # traced with tf.function/autograph (no decorator is visible here), which
    # turns it into graph control flow.
    if tf.equal(has_non_finite_weight, 0):
        # The delta is negated before being handed to `optimizer.next`.
        # Presumably `next` treats this argument as a gradient to descend
        # along.
        negative_weights_delta = tf.nest.map_structure(
            lambda w: -1.0 * w, weights_delta)
        optimizer_state, updated_weights = optimizer.next(
            optimizer_state, global_model_weights.trainable,
            negative_weights_delta)
        if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
            # Keras optimizer mutates model variables within the `next` step.
            tf.nest.map_structure(lambda a, b: a.assign(b),
                                  global_model_weights.trainable,
                                  updated_weights)

    # Create a new state based on the updated model.
    return structure.update_struct(
        server_state,
        model=global_model_weights,
        optimizer_state=optimizer_state,
        model_broadcast_state=broadcaster_state,
        delta_aggregate_state=aggregator_state,
    )
def test_update_struct_dict(self):
    """`update_struct` works on a plain `dict`.

    The input is a `dict` literal rather than an `OrderedDict`, so the
    plain-mapping path is actually exercised. Updates are chained: first `c`,
    then `a` on the already-updated mapping.
    """
    state = {'a': 1, 'b': 2, 'c': 3}
    state2 = structure.update_struct(state, c=7)
    self.assertEqual(state2, {'a': 1, 'b': 2, 'c': 7})
    state3 = structure.update_struct(state2, a=8)
    self.assertEqual(state3, {'a': 8, 'b': 2, 'c': 7})
def test_update_struct(self):
    """`update_struct` on `structure.Struct` values with various naming schemes."""
    with self.subTest('fully_named'):
        named = structure.Struct([('a', 1), ('b', 2), ('c', 3)])
        named = structure.update_struct(named, c=7)
        self.assertEqual(named,
                         structure.Struct([('a', 1), ('b', 2), ('c', 7)]))
        named = structure.update_struct(named, a=8)
        self.assertEqual(named,
                         structure.Struct([('a', 8), ('b', 2), ('c', 7)]))

    with self.subTest('partially_named'):
        mixed = structure.Struct([(None, 1), ('b', 2), (None, 3)])
        mixed = structure.update_struct(mixed, b=7)
        self.assertEqual(mixed,
                         structure.Struct([(None, 1), ('b', 7), (None, 3)]))
        # 'a' is not among this struct's names.
        with self.assertRaises(KeyError):
            structure.update_struct(mixed, a=8)

    with self.subTest('nested'):
        # Updating a field that holds a nested container replaces the value.
        nested = structure.Struct([('a', {'a1': 1, 'a2': 2}), ('b', 2),
                                   ('c', 3)])
        nested = structure.update_struct(nested, a=7)
        self.assertEqual(nested,
                         structure.Struct([('a', 7), ('b', 2), ('c', 3)]))
        nested = structure.update_struct(nested, a={'foo': 1, 'bar': 2})
        expected_nested = structure.Struct([('a', {'foo': 1, 'bar': 2}),
                                            ('b', 2), ('c', 3)])
        self.assertEqual(nested, expected_nested)

    with self.subTest('unnamed'):
        # With no names at all, every keyword is an unknown field.
        unnamed = structure.Struct([(None, value) for value in range(3)])
        for missing_field in ('a', 'b'):
            with self.assertRaises(KeyError):
                structure.update_struct(unnamed, **{missing_field: 1})
def test_update_struct_on_dict_does_not_mutate_original(self):
    """`update_struct` leaves the mapping it was given unchanged."""
    original = collections.OrderedDict(a=1, b=2, c=3)
    # Only the input is inspected afterwards, so the result is discarded.
    structure.update_struct(original, c=7)
    self.assertEqual(original, collections.OrderedDict(a=1, b=2, c=3))
def test_update_struct(self):
    """Checks `structure.update_struct` across differently named `Struct`s."""
    with self.subTest('fully_named'):
        current = structure.Struct.named(a=1, b=2, c=3)
        current = structure.update_struct(current, c=7)
        self.assertEqual(current, structure.Struct.named(a=1, b=2, c=7))
        current = structure.update_struct(current, a=8)
        self.assertEqual(current, structure.Struct.named(a=8, b=2, c=7))

    with self.subTest('partially_named'):
        current = structure.Struct([(None, 1), ('b', 2), (None, 3)])
        current = structure.update_struct(current, b=7)
        self.assertEqual(current,
                         structure.Struct([(None, 1), ('b', 7), (None, 3)]))
        # 'a' is not a field name of this struct.
        with self.assertRaises(KeyError):
            structure.update_struct(current, a=8)

    with self.subTest('nested'):
        current = structure.Struct.named(a={'a1': 1, 'a2': 2}, b=2, c=3)
        current = structure.update_struct(current, a=7)
        self.assertEqual(current, structure.Struct.named(a=7, b=2, c=3))
        current = structure.update_struct(current, a={'foo': 1, 'bar': 2})
        self.assertEqual(
            current,
            structure.Struct.named(a={'foo': 1, 'bar': 2}, b=2, c=3))

    with self.subTest('unnamed'):
        current = structure.Struct.unnamed(*range(3))
        for field_name in ('a', 'b'):
            with self.assertRaises(KeyError):
                structure.update_struct(current, **{field_name: 1})