def test_update_state_namedtuple(self):
    """Checks that `update_state` replaces single fields of a namedtuple."""
    triple_cls = collections.namedtuple('my_tuple_type', 'a b c')
    original = triple_cls(1, 2, 3)

    # Replace the last field; the other two must be carried over unchanged.
    with_new_c = computation_utils.update_state(original, c=7)
    self.assertEqual(with_new_c, triple_cls(1, 2, 7))

    # Replace the first field, starting from the previous result.
    with_new_a = computation_utils.update_state(with_new_c, a=8)
    self.assertEqual(with_new_a, triple_cls(8, 2, 7))
 def test_update_state_ordereddict(self):
     """Checks that `update_state` replaces values of an OrderedDict by key."""

     def ordered(a, b, c):
         # Build the three-key OrderedDict used throughout this test.
         return collections.OrderedDict([('a', a), ('b', b), ('c', c)])

     initial = ordered(1, 2, 3)
     after_c = computation_utils.update_state(initial, c=7)
     self.assertEqual(after_c, ordered(1, 2, 7))
     after_a = computation_utils.update_state(after_c, a=8)
     self.assertEqual(after_a, ordered(8, 2, 7))
    def test_update_state(self):
        """Checks namedtuple field replacement and rejection of plain tuples."""
        Triple = collections.namedtuple('MyTuple', 'a b c')  # pylint: disable=invalid-name
        # Each step applies its overrides to the result of the previous step.
        expectations = (
            ({'c': 7}, Triple(1, 2, 7)),
            ({'a': 8}, Triple(8, 2, 7)),
        )
        current = Triple(1, 2, 3)
        for changes, expected in expectations:
            current = computation_utils.update_state(current, **changes)
            self.assertEqual(current, expected)

        # A plain tuple must be rejected with a TypeError mentioning namedtuple.
        with six.assertRaisesRegex(self, TypeError, r'state.*namedtuple'):
            computation_utils.update_state((1, 2, 3), a=8)
    def test_update_state_attrs(self):
        """Checks that `update_state` replaces attributes of an `attr.s` object."""

        @attr.s
        class TestAttrsClass(object):
            a = attr.ib()
            b = attr.ib()
            c = attr.ib()

        # (overrides, expected field values after applying them in sequence)
        steps = [
            (dict(c=7), (1, 2, 7)),
            (dict(a=8), (8, 2, 7)),
        ]
        state = TestAttrsClass(1, 2, 3)
        for overrides, expected_fields in steps:
            state = computation_utils.update_state(state, **overrides)
            self.assertEqual(state, TestAttrsClass(*expected_fields))
# Example #5
# 0
def state_with_new_model_weights(
    server_state: ServerState,
    trainable_weights: List[np.ndarray],
    non_trainable_weights: List[np.ndarray],
) -> ServerState:
    """Returns a `ServerState` with updated model weights.

    Args:
      server_state: a server state object returned by an iterative training
        process like `tff.learning.build_federated_averaging_process`.
      trainable_weights: a list of `numpy` arrays in the order of the original
        model's `trainable_variables`.
      non_trainable_weights: a list of `numpy` arrays in the order of the
        original model's `non_trainable_variables`.

    Returns:
      A new server `ServerState` object which can be passed to the `next`
      method of the iterative process.

    Raises:
      TypeError: if `server_state` is not a `ServerState` (as enforced by
        `py_typecheck.check_type`), or if the new weights differ from the
        existing model weights in structure, length, dtype or shape, or
        contain unsupported element types.
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc` (available since Python 3.3).
    from collections.abc import Sequence  # pylint: disable=g-import-not-at-top

    py_typecheck.check_type(server_state, ServerState)
    leaf_types = (int, float, np.ndarray, tf.Tensor)
    # str/bytes are Sequences whose elements are again str/bytes, so recursing
    # into them would never terminate; treat them as unsupported instead.
    text_types = (str, bytes)

    def _dtype_and_shape(value):
        """Returns `(dtype, shape)` of a leaf value."""
        if isinstance(value, (int, float)):
            # Python scalars have no `dtype`/`shape`; view them as 0-d arrays.
            value = np.asarray(value)
        return value.dtype, value.shape

    def assert_weight_lists_match(old_value, new_value):
        """Raises `TypeError` unless two (nested) weight structures match."""
        if isinstance(new_value, leaf_types) and isinstance(
                old_value, leaf_types):
            old_dtype, old_shape = _dtype_and_shape(old_value)
            new_dtype, new_shape = _dtype_and_shape(new_value)
            if old_dtype != new_dtype or old_shape != new_shape:
                raise TypeError('Element is not the same tensor type. old '
                                f'({old_dtype}, {old_shape}) != '
                                f'new ({new_dtype}, {new_shape})')
        elif (isinstance(new_value, Sequence)
              and isinstance(old_value, Sequence)
              and not isinstance(new_value, text_types)
              and not isinstance(old_value, text_types)):
            if len(old_value) != len(new_value):
                raise TypeError(
                    'Model weights have different lengths: '
                    f'(old) {len(old_value)} != (new) {len(new_value)}\n'
                    f'Old values: {old_value}\nNew values: {new_value}')
            # Compare element-wise; nested sequences are checked recursively.
            for old, new in zip(old_value, new_value):
                assert_weight_lists_match(old, new)
        else:
            raise TypeError(
                'Model weights structures contains types that cannot be '
                'handled.\nOld weights structure: {old}\n'
                'New weights structure: {new}\n'
                'Must be one of (int, float, np.ndarray, tf.Tensor, '
                'collections.abc.Sequence)'.format(
                    old=tf.nest.map_structure(type, old_value),
                    new=tf.nest.map_structure(type, new_value)))

    assert_weight_lists_match(server_state.model.trainable, trainable_weights)
    assert_weight_lists_match(server_state.model.non_trainable,
                              non_trainable_weights)
    new_server_state = computation_utils.update_state(
        server_state,
        model=model_utils.ModelWeights(trainable=trainable_weights,
                                       non_trainable=non_trainable_weights))
    return new_server_state
 def test_update_state_fails(self):
     """Checks that non-structures and unknown field names are rejected."""
     # Both a tuple and a list are unnamed sequences, not named structures.
     for unnamed_state in ((1, 2, 3), [1, 2, 3]):
         with self.assertRaisesRegex(TypeError, 'state must be a structure'):
             computation_utils.update_state(unnamed_state, a=8)

     # The dict is a structure, but it has no field 'a' to update.
     with self.assertRaisesRegex(KeyError, 'does not contain a field'):
         computation_utils.update_state({'z': 1}, a=8)
 def test_update_state_dict(self):
     """Checks that `update_state` replaces values of a plain dict by key."""
     current = {'a': 1, 'b': 2, 'c': 3}
     # Each step applies its overrides to the result of the previous step.
     for overrides, expected in (
         ({'c': 7}, {'a': 1, 'b': 2, 'c': 7}),
         ({'a': 8}, {'a': 8, 'b': 2, 'c': 7}),
     ):
         current = computation_utils.update_state(current, **overrides)
         self.assertEqual(current, expected)
 def test_update_state_ordereddict_equals_dict(self):
     """Checks that updating an OrderedDict yields values equal to plain dicts.

     This method was previously also named `test_update_state_dict`. That
     duplicated another method of the same name; inside one class the later
     definition silently replaces the earlier one, so one of the two tests
     never ran. A distinct name keeps both.
     """
     state = collections.OrderedDict([('a', 1), ('b', 2), ('c', 3)])
     state2 = computation_utils.update_state(state, c=7)
     # Plain-dict comparison is order-insensitive, so only keys/values matter.
     self.assertEqual(state2, {'a': 1, 'b': 2, 'c': 7})
     state3 = computation_utils.update_state(state2, a=8)
     self.assertEqual(state3, {'a': 8, 'b': 2, 'c': 7})
 def test_update_state_tff_struct(self):
     """Checks `update_state` on fully, partially, nested and unnamed Structs."""
     with self.subTest('fully_named'):
         state = Struct([('a', 1), ('b', 2), ('c', 3)])
         # Each step applies its overrides to the result of the previous step.
         for overrides, expected_fields in (
             ({'c': 7}, [('a', 1), ('b', 2), ('c', 7)]),
             ({'a': 8}, [('a', 8), ('b', 2), ('c', 7)]),
         ):
             state = computation_utils.update_state(state, **overrides)
             self.assertEqual(state, Struct(expected_fields))

     with self.subTest('partially_named'):
         state = Struct([(None, 1), ('b', 2), (None, 3)])
         state = computation_utils.update_state(state, b=7)
         self.assertEqual(state, Struct([(None, 1), ('b', 7), (None, 3)]))
         # 'a' is not one of the named fields, so updating it must fail.
         with self.assertRaises(KeyError):
             computation_utils.update_state(state, a=8)

     with self.subTest('nested'):
         state = Struct([('a', {'a1': 1, 'a2': 2}), ('b', 2), ('c', 3)])
         # A nested value can be replaced wholesale by a scalar...
         state = computation_utils.update_state(state, a=7)
         self.assertEqual(state, Struct([('a', 7), ('b', 2), ('c', 3)]))
         # ...and a scalar by a differently shaped nested value.
         replacement = {'foo': 1, 'bar': 2}
         state = computation_utils.update_state(state, a=replacement)
         expected = Struct([('a', {'foo': 1, 'bar': 2}), ('b', 2), ('c', 3)])
         self.assertEqual(state, expected)

     with self.subTest('unnamed'):
         state = Struct((None, i) for i in range(3))
         # With no names at all, every field lookup must fail.
         for field_name in ('a', 'b'):
             with self.assertRaises(KeyError):
                 computation_utils.update_state(state, **{field_name: 1})