Example #1
0
 def test_update_struct_namedtuple(self):
   """Replacing namedtuple fields by keyword yields the expected tuples."""
   triple_cls = collections.namedtuple('my_tuple_type', 'a b c')
   current = triple_cls(1, 2, 3)
   # Each step applies one keyword override to the result of the previous one.
   for overrides, expected in (
       ({'c': 7}, triple_cls(1, 2, 7)),
       ({'a': 8}, triple_cls(8, 2, 7)),
   ):
     current = structure.update_struct(current, **overrides)
     self.assertEqual(current, expected)
Example #2
0
 def test_update_struct_ordereddict(self):
   """update_struct on an OrderedDict replaces values and keeps key order."""
   ordered = collections.OrderedDict
   current = ordered([('a', 1), ('b', 2), ('c', 3)])
   # OrderedDict equality is order-sensitive, so key order is checked too.
   steps = [
       ({'c': 7}, ordered([('a', 1), ('b', 2), ('c', 7)])),
       ({'a': 8}, ordered([('a', 8), ('b', 2), ('c', 7)])),
   ]
   for overrides, expected in steps:
     current = structure.update_struct(current, **overrides)
     self.assertEqual(current, expected)
Example #3
0
  def test_update_struct_attrs(self):
    """update_struct returns attrs instances with the named field replaced."""

    @attr.s
    class TestAttrsClass(object):
      a = attr.ib()
      b = attr.ib()
      c = attr.ib()

    before = TestAttrsClass(a=1, b=2, c=3)
    after_c = structure.update_struct(before, c=7)
    self.assertEqual(after_c, TestAttrsClass(a=1, b=2, c=7))
    # Chain a second update on top of the first result.
    after_a = structure.update_struct(after_c, a=8)
    self.assertEqual(after_a, TestAttrsClass(a=8, b=2, c=7))
Example #4
0
 def test_update_struct_fails(self):
   """Plain sequences raise TypeError; unknown mapping keys raise KeyError."""
   # Tuples and lists have no field names, so they are rejected outright.
   for not_a_structure in ((1, 2, 3), [1, 2, 3]):
     with self.assertRaisesRegex(TypeError, '`structure` must be a structure'):
       structure.update_struct(not_a_structure, a=8)
   # A mapping is accepted, but only keys it already contains may be updated.
   with self.assertRaisesRegex(KeyError, 'does not contain a field'):
     structure.update_struct({'z': 1}, a=8)
Example #5
0
    def server_update(server_state, weights_delta, aggregator_state,
                      broadcaster_state):
        """Updates the `server_state` based on `weights_delta`.

        Loads `server_state.model` into the variables of a model built by
        `model_fn`. It then applies the negated `weights_delta` with the server
        optimizer. That step is skipped when the delta contains non-finite
        values. Finally it returns `server_state` updated (via
        `structure.update_struct`) with the new model, optimizer state,
        broadcaster state and aggregator state.

        Note: `model_fn` and `server_optimizer_fn` are not parameters. They are
        read from the enclosing scope, which is not shown here.

        Args:
          server_state: A `tff.learning.framework.ServerState`, the state to be
            updated.
          weights_delta: The model delta in global trainable variables from
            clients.
          aggregator_state: The state of the aggregator after performing
            aggregation.
          broadcaster_state: The state of the broadcaster after broadcasting.

        Returns:
          The updated `tff.learning.framework.ServerState`.
        """
        # tf.init_scope lifts op creation out of function-building graphs.
        # NOTE(review): presumably used so the model's variables are created
        # outside the traced computation -- confirm.
        with tf.init_scope():
            model = model_fn()
        global_model_weights = reconstruction_utils.get_global_variables(model)
        # NOTE(review): the effect of `disjoint_init_and_next=True` is defined
        # in keras_optimizer (not shown) -- confirm before relying on it.
        optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            server_optimizer_fn,
            global_model_weights.trainable,
            disjoint_init_and_next=True)
        optimizer_state = server_state.optimizer_state

        # Initialize the model with the current state.
        # Every model variable is assigned the matching entry of
        # `server_state.model` (both structures must match for tf.nest).
        tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                              server_state.model)

        # The helper returns the (possibly replaced) delta plus a flag. Per its
        # name it zeroes everything when any entry is non-finite -- verify in
        # tensor_utils.
        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        # We ignore the update if the weights_delta is non finite.
        # NOTE(review): a Python `if` on a tensor condition presumably relies
        # on this function being converted by autograph/tf.function -- confirm.
        if tf.equal(has_non_finite_weight, 0):
            # The delta is negated before reaching the optimizer. Presumably
            # `next` treats its last argument as a gradient and steps against
            # it, so negation moves the weights along `weights_delta` --
            # confirm against the optimizer's contract.
            negative_weights_delta = tf.nest.map_structure(
                lambda w: -1.0 * w, weights_delta)
            optimizer_state, updated_weights = optimizer.next(
                optimizer_state, global_model_weights.trainable,
                negative_weights_delta)
            if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
                # Keras optimizer mutates model variables within the `next` step.
                # For any other optimizer the new values come back in
                # `updated_weights` and are copied into the variables here.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      global_model_weights.trainable,
                                      updated_weights)

        # Create a new state based on the updated model.
        # If the update was skipped, `optimizer_state` is still the input
        # state and the model variables hold `server_state.model`.
        return structure.update_struct(
            server_state,
            model=global_model_weights,
            optimizer_state=optimizer_state,
            model_broadcast_state=broadcaster_state,
            delta_aggregate_state=aggregator_state,
        )
Example #6
0
 def test_update_struct_dict(self):
   """update_struct on a plain `dict` returns the dict with fields replaced."""
   # Build the input as a plain dict rather than an OrderedDict, so that this
   # test exercises the builtin-dict path its name promises.
   state = {'a': 1, 'b': 2, 'c': 3}
   state2 = structure.update_struct(state, c=7)
   self.assertEqual(state2, {'a': 1, 'b': 2, 'c': 7})
   state3 = structure.update_struct(state2, a=8)
   self.assertEqual(state3, {'a': 8, 'b': 2, 'c': 7})
Example #7
0
 def test_update_struct(self):
   """update_struct on `structure.Struct`: named, partial, nested, unnamed."""
   with self.subTest('fully_named'):
     named = structure.update_struct(
         structure.Struct([('a', 1), ('b', 2), ('c', 3)]), c=7)
     self.assertEqual(named, structure.Struct([('a', 1), ('b', 2), ('c', 7)]))
     named = structure.update_struct(named, a=8)
     self.assertEqual(named, structure.Struct([('a', 8), ('b', 2), ('c', 7)]))
   with self.subTest('partially_named'):
     partial = structure.update_struct(
         structure.Struct([(None, 1), ('b', 2), (None, 3)]), b=7)
     self.assertEqual(partial,
                      structure.Struct([(None, 1), ('b', 7), (None, 3)]))
     # Only 'b' is a field name here, so updating 'a' is rejected.
     with self.assertRaises(KeyError):
       structure.update_struct(partial, a=8)
   with self.subTest('nested'):
     # The whole field value is replaced, whatever it held before.
     nested = structure.Struct([('a', {'a1': 1, 'a2': 2}), ('b', 2), ('c', 3)])
     nested = structure.update_struct(nested, a=7)
     self.assertEqual(nested, structure.Struct([('a', 7), ('b', 2), ('c', 3)]))
     nested = structure.update_struct(nested, a={'foo': 1, 'bar': 2})
     expected = structure.Struct([('a', {
         'foo': 1,
         'bar': 2
     }), ('b', 2), ('c', 3)])
     self.assertEqual(nested, expected)
   with self.subTest('unnamed'):
     unnamed = structure.Struct((None, i) for i in range(3))
     # Without any field names, every keyword update is a KeyError.
     for missing in ('a', 'b'):
       with self.assertRaises(KeyError):
         structure.update_struct(unnamed, **{missing: 1})
Example #8
0
 def test_update_struct_on_dict_does_not_mutate_original(self):
   """update_struct leaves its input mapping unchanged."""
   original = collections.OrderedDict(a=1, b=2, c=3)
   # The returned value is irrelevant; only the effect on `original` matters.
   structure.update_struct(original, c=7)
   self.assertEqual(original, collections.OrderedDict(a=1, b=2, c=3))
Example #9
0
 def test_update_struct(self):
   """update_struct on Structs built with `Struct.named` / `Struct.unnamed`."""
   with self.subTest('fully_named'):
     result = structure.update_struct(
         structure.Struct.named(a=1, b=2, c=3), c=7)
     self.assertEqual(result, structure.Struct.named(a=1, b=2, c=7))
     result = structure.update_struct(result, a=8)
     self.assertEqual(result, structure.Struct.named(a=8, b=2, c=7))
   with self.subTest('partially_named'):
     mixed = structure.Struct([(None, 1), ('b', 2), (None, 3)])
     mixed = structure.update_struct(mixed, b=7)
     self.assertEqual(mixed,
                      structure.Struct([(None, 1), ('b', 7), (None, 3)]))
     # Only 'b' carries a name, so 'a' cannot be updated.
     with self.assertRaises(KeyError):
       structure.update_struct(mixed, a=8)
   with self.subTest('nested'):
     # The whole field value is swapped out, regardless of its previous type.
     result = structure.update_struct(
         structure.Struct.named(a=dict(a1=1, a2=2), b=2, c=3), a=7)
     self.assertEqual(result, structure.Struct.named(a=7, b=2, c=3))
     result = structure.update_struct(result, a=dict(foo=1, bar=2))
     self.assertEqual(
         result, structure.Struct.named(a=dict(foo=1, bar=2), b=2, c=3))
   with self.subTest('unnamed'):
     positional = structure.Struct.unnamed(*range(3))
     for field in ('a', 'b'):
       with self.assertRaises(KeyError):
         structure.update_struct(positional, **{field: 1})