def test_update_state_namedtuple(self):
  """`update_state` yields a namedtuple with only the named fields replaced."""
  Triple = collections.namedtuple('my_tuple_type', 'a b c')  # pylint: disable=invalid-name
  original = Triple(1, 2, 3)

  with_c = computation_utils.update_state(original, c=7)
  self.assertEqual(with_c, Triple(1, 2, 7))

  with_a_and_c = computation_utils.update_state(with_c, a=8)
  self.assertEqual(with_a_and_c, Triple(8, 2, 7))
def test_update_state_ordereddict(self):
  """`update_state` replaces values in an `OrderedDict` and keeps key order."""

  def ordered(a, b, c):
    # OrderedDict equality is order-sensitive, so this pins key order too.
    return collections.OrderedDict([('a', a), ('b', b), ('c', c)])

  state = ordered(1, 2, 3)
  state = computation_utils.update_state(state, c=7)
  self.assertEqual(state, ordered(1, 2, 7))
  state = computation_utils.update_state(state, a=8)
  self.assertEqual(state, ordered(8, 2, 7))
def test_update_state(self):
  """Checks namedtuple field updates and that a plain tuple is rejected."""
  MyTuple = collections.namedtuple('MyTuple', 'a b c')  # pylint: disable=invalid-name
  t = MyTuple(1, 2, 3)
  t2 = computation_utils.update_state(t, c=7)
  self.assertEqual(t2, MyTuple(1, 2, 7))
  t3 = computation_utils.update_state(t2, a=8)
  self.assertEqual(t3, MyTuple(8, 2, 7))
  # A plain tuple has no field names, so a keyword update must raise.
  # `unittest.TestCase.assertRaisesRegex` is built in on Python 3; the
  # Python-2 `six.assertRaisesRegex(self, ...)` shim is not needed.
  with self.assertRaisesRegex(TypeError, r'state.*namedtuple'):
    computation_utils.update_state((1, 2, 3), a=8)
def test_update_state_attrs(self):
  """`update_state` replaces attributes on an `attr.s`-decorated instance."""

  @attr.s
  class TestAttrsClass:
    a = attr.ib()
    b = attr.ib()
    c = attr.ib()

  updated_c = computation_utils.update_state(TestAttrsClass(1, 2, 3), c=7)
  self.assertEqual(updated_c, TestAttrsClass(1, 2, 7))

  updated_a = computation_utils.update_state(updated_c, a=8)
  self.assertEqual(updated_a, TestAttrsClass(8, 2, 7))
def state_with_new_model_weights(
    server_state: ServerState,
    trainable_weights: List[np.ndarray],
    non_trainable_weights: List[np.ndarray],
) -> ServerState:
  """Returns a `ServerState` with updated model weights.

  The new weights are first checked leaf-by-leaf against the weights already
  in `server_state.model`; dtypes, shapes and sequence lengths must agree.

  Args:
    server_state: a server state object returned by an iterative training
      process like `tff.learning.build_federated_averaging_process`.
    trainable_weights: a list of `numpy` arrays in the order of the original
      model's `trainable_variables`.
    non_trainable_weights: a list of `numpy` arrays in the order of the
      original model's `non_trainable_variables`.

  Returns:
    A new server `ServerState` object which can be passed to the `next` method
    of the iterative process.

  Raises:
    TypeError: if the new weights differ from the existing ones in structure,
      sequence length, dtype or shape, or contain unsupported types. A
      non-`ServerState` `server_state` is rejected by
      `py_typecheck.check_type` (presumably also with `TypeError`).
  """
  # `collections.Sequence` was removed in Python 3.10; use `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  py_typecheck.check_type(server_state, ServerState)
  # `np.generic` covers numpy scalars (e.g. `np.float32(1.0)`), which expose
  # `dtype` and `shape` just like arrays.
  leaf_types = (int, float, np.generic, np.ndarray, tf.Tensor)

  def _dtype_and_shape(value):
    """Returns `(dtype, shape)` of a leaf, viewing Python scalars via numpy."""
    if isinstance(value, (int, float)):
      # Plain Python ints/floats have no `dtype`/`shape` attributes.
      value = np.asarray(value)
    return value.dtype, value.shape

  def _is_sequence(value):
    """True for non-string sequences (lists, tuples, ...)."""
    # A `str` is a Sequence whose items are themselves 1-char strings, which
    # would make the recursion below never terminate.
    return (isinstance(value, collections_abc.Sequence) and
            not isinstance(value, (str, bytes)))

  def assert_weight_lists_match(old_value, new_value):
    """Assert two flat lists of ndarrays or tensors match."""
    if isinstance(new_value, leaf_types) and isinstance(old_value, leaf_types):
      old_dtype, old_shape = _dtype_and_shape(old_value)
      new_dtype, new_shape = _dtype_and_shape(new_value)
      if old_dtype != new_dtype or old_shape != new_shape:
        raise TypeError('Element is not the same tensor type. old '
                        f'({old_dtype}, {old_shape}) != '
                        f'new ({new_dtype}, {new_shape})')
    elif _is_sequence(new_value) and _is_sequence(old_value):
      if len(old_value) != len(new_value):
        raise TypeError(
            'Model weights have different lengths: '
            f'(old) {len(old_value)} != (new) {len(new_value)}\n'
            f'Old values: {old_value}\nNew values: {new_value}')
      for old, new in zip(old_value, new_value):
        assert_weight_lists_match(old, new)
    else:
      raise TypeError(
          'Model weights structures contains types that cannot be '
          'handled.\nOld weights structure: {old}\n'
          'New weights structure: {new}\n'
          'Must be one of (int, float, np.generic, np.ndarray, tf.Tensor, '
          'collections.abc.Sequence)'.format(
              old=tf.nest.map_structure(type, old_value),
              new=tf.nest.map_structure(type, new_value)))

  assert_weight_lists_match(server_state.model.trainable, trainable_weights)
  assert_weight_lists_match(server_state.model.non_trainable,
                            non_trainable_weights)
  new_server_state = computation_utils.update_state(
      server_state,
      model=model_utils.ModelWeights(
          trainable=trainable_weights, non_trainable=non_trainable_weights))
  return new_server_state
def test_update_state_fails(self):
  """`update_state` rejects unstructured sequences and unknown field names."""
  for unstructured in ((1, 2, 3), [1, 2, 3]):
    with self.assertRaisesRegex(TypeError, 'state must be a structure'):
      computation_utils.update_state(unstructured, a=8)

  # The dict only has field 'z', so updating 'a' is an unknown field.
  with self.assertRaisesRegex(KeyError, 'does not contain a field'):
    computation_utils.update_state({'z': 1}, a=8)
def test_update_state_dict(self):
  """`update_state` replaces values in a plain `dict`."""
  initial = {'a': 1, 'b': 2, 'c': 3}

  after_c = computation_utils.update_state(initial, c=7)
  self.assertEqual(after_c, {'a': 1, 'b': 2, 'c': 7})

  after_a = computation_utils.update_state(after_c, a=8)
  self.assertEqual(after_a, {'a': 8, 'b': 2, 'c': 7})
def test_update_state_dict(self):
  """An updated `OrderedDict` state compares equal to the expected plain dict.

  NOTE(review): this method shares its name with another
  `test_update_state_dict`; if both live in the same test class, only the
  later definition is collected and run.
  """
  ordered_state = collections.OrderedDict([('a', 1), ('b', 2), ('c', 3)])

  first = computation_utils.update_state(ordered_state, c=7)
  self.assertEqual(first, {'a': 1, 'b': 2, 'c': 7})

  second = computation_utils.update_state(first, a=8)
  self.assertEqual(second, {'a': 8, 'b': 2, 'c': 7})
def test_update_state_tff_struct(self):
  """`update_state` on fully named, partially named, nested and unnamed Structs."""
  with self.subTest('fully_named'):
    named = Struct([('a', 1), ('b', 2), ('c', 3)])
    named = computation_utils.update_state(named, c=7)
    self.assertEqual(named, Struct([('a', 1), ('b', 2), ('c', 7)]))
    named = computation_utils.update_state(named, a=8)
    self.assertEqual(named, Struct([('a', 8), ('b', 2), ('c', 7)]))

  with self.subTest('partially_named'):
    partial = Struct([(None, 1), ('b', 2), (None, 3)])
    partial = computation_utils.update_state(partial, b=7)
    self.assertEqual(partial, Struct([(None, 1), ('b', 7), (None, 3)]))
    # 'a' is not one of the Struct's names.
    with self.assertRaises(KeyError):
      computation_utils.update_state(partial, a=8)

  with self.subTest('nested'):
    nested = Struct([('a', {'a1': 1, 'a2': 2}), ('b', 2), ('c', 3)])
    # The whole value stored under 'a' is replaced.
    nested = computation_utils.update_state(nested, a=7)
    self.assertEqual(nested, Struct([('a', 7), ('b', 2), ('c', 3)]))
    replacement = {'foo': 1, 'bar': 2}
    nested = computation_utils.update_state(nested, a=replacement)
    self.assertEqual(
        nested, Struct([('a', {'foo': 1, 'bar': 2}), ('b', 2), ('c', 3)]))

  with self.subTest('unnamed'):
    unnamed = Struct((None, i) for i in range(3))
    for missing_field in ('a', 'b'):
      with self.assertRaises(KeyError):
        computation_utils.update_state(unnamed, **{missing_field: 1})