Exemplo n.º 1
0
 def test_multiple_named_and_unnamed(self):
     """Checks an `AnonymousTuple` holding one unnamed and two named elements.

     Covers length, positional and attribute access, iteration, `dir()`,
     equality, `to_elements`, `repr`/`str`, and that `to_odict` rejects a
     tuple containing an unnamed element.
     """
     elements = [(None, 10), ('foo', 20), ('bar', 30)]
     x = anonymous_tuple.AnonymousTuple(elements)

     # Length, positional access and iteration.
     self.assertLen(x, 3)
     for index, expected in enumerate([10, 20, 30]):
         self.assertEqual(x[index], expected)
     with self.assertRaises(IndexError):
         _ = x[3]
     self.assertEqual(list(iter(x)), [10, 20, 30])

     # Only the named elements are exposed as attributes.
     self.assertEqual(dir(x), ['bar', 'foo'])
     self.assertEqual(x.foo, 20)
     self.assertEqual(x.bar, 30)
     with self.assertRaises(AttributeError):
         _ = x.baz

     # Equality depends on which position carries which name.
     same_layout = anonymous_tuple.AnonymousTuple(
         [(None, 10), ('foo', 20), ('bar', 30)])
     self.assertEqual(x, same_layout)
     shifted_names = anonymous_tuple.AnonymousTuple(
         [('foo', 10), ('bar', 20), (None, 30)])
     self.assertNotEqual(x, shifted_names)

     # Conversions and string forms.
     self.assertEqual(anonymous_tuple.to_elements(x), elements)
     expected_repr = 'AnonymousTuple([(None, 10), (foo, 20), (bar, 30)])'
     self.assertEqual(repr(x), expected_repr)
     self.assertEqual(str(x), '<10,foo=20,bar=30>')
     with self.assertRaisesRegex(ValueError, 'unnamed'):
         anonymous_tuple.to_odict(x)
Exemplo n.º 2
0
 def _server_state_from_tff_result(self, result):
     """Rebuilds a `tff.learning.framework.ServerState` from a TFF result.

     Args:
       result: A value exposing `model.trainable`, `model.non_trainable`,
         `optimizer_state`, `delta_aggregate_state` and
         `model_broadcast_state`.

     Returns:
       A `ServerState` whose aggregate state is a list of recursively
       converted ordered dicts (one per element) when
       `self._per_vector_clipping` is set, and a single recursively
       converted ordered dict otherwise.
     """
     if not self._per_vector_clipping:
         per_vector_aggregate_states = anonymous_tuple.to_odict(
             result.delta_aggregate_state, recursive=True)
     else:
         # One aggregate state per element; element names are discarded.
         per_vector_aggregate_states = []
         for _, elt in anonymous_tuple.iter_elements(
                 result.delta_aggregate_state):
             per_vector_aggregate_states.append(
                 anonymous_tuple.to_odict(elt, recursive=True))

     model_weights = tff.learning.ModelWeights(
         tuple(result.model.trainable), tuple(result.model.non_trainable))
     optimizer_state = list(result.optimizer_state)
     broadcast_state = tuple(result.model_broadcast_state)
     return tff.learning.framework.ServerState(
         model_weights, optimizer_state, per_vector_aggregate_states,
         broadcast_state)
Exemplo n.º 3
0
def _get_accumulator_type(member_type):
    """Constructs a `tff.Type` for the accumulator in sample aggregation.

    Args:
      member_type: A `tff.Type` representing the member components of the
        federated type.

    Returns:
      The `tff.NamedTupleType` associated with the accumulator. The tuple
      contains two parts, `accumulators` and `rands`, that are parallel lists
      (e.g. the i-th index in one corresponds to the i-th index in the other).
      These two lists are used to sample from the accumulators with equal
      probability.
    """
    # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
    if isinstance(member_type, tff.NamedTupleType):
        # Prepend an unknown-size leading dimension to every leaf tensor type.
        a = anonymous_tuple.map_structure(
            lambda v: tff.TensorType(v.dtype, [None] + v.shape.dims),
            member_type)
        accumulators_type = tff.NamedTupleType(
            anonymous_tuple.to_odict(a, True))
    else:
        accumulators_type = tff.TensorType(
            member_type.dtype, shape=[None] + member_type.shape.dims)

    # Build the OrderedDict from an ordered sequence of pairs rather than a
    # dict literal, so the field order ('accumulators', then 'rands') does not
    # depend on the interpreter preserving dict-literal insertion order.
    return tff.NamedTupleType(
        collections.OrderedDict([
            ('accumulators', accumulators_type),
            ('rands', tff.TensorType(tf.float32, shape=[None])),
        ]))
Exemplo n.º 4
0
 def _server_state_from_tff_result(self, result):
     """Rebuilds a `tff.learning.framework.ServerState` from a TFF result.

     Args:
       result: A value exposing `model.trainable`, `model.non_trainable`,
         `optimizer_state`, `delta_aggregate_state` and
         `model_broadcast_state`.

     Returns:
       A `ServerState` built from model weights (as tuples), the optimizer
       state (as a list), the aggregate state converted with `to_odict`, and
       the broadcast state (as a tuple).
     """
     model_weights = tff.learning.ModelWeights(
         tuple(result.model.trainable), tuple(result.model.non_trainable))
     optimizer_state = list(result.optimizer_state)
     # NOTE(review): the positional `True` is presumably `to_odict`'s
     # `recursive` flag -- confirm against the `anonymous_tuple` module.
     aggregate_state = anonymous_tuple.to_odict(
         result.delta_aggregate_state, True)
     broadcast_state = tuple(result.model_broadcast_state)
     return tff.learning.framework.ServerState(
         model_weights, optimizer_state, aggregate_state, broadcast_state)
Exemplo n.º 5
0
 def accumlator_type_fn():
   """Builds the empty initial `_Samples` value for the accumulators.

   Reads `member_type` from the enclosing scope.

   Returns:
     A `_Samples` whose first field holds zero tensors with a leading
     dimension of size 0 (an ordered dict of them when `member_type` is a
     `tff.NamedTupleType`), and whose second field is an empty float32
     tensor.
   """
   # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
   if isinstance(member_type, tff.NamedTupleType):
     a = anonymous_tuple.map_structure(
         lambda v: tf.zeros([0] + v.shape.dims, v.dtype), member_type)
     return _Samples(anonymous_tuple.to_odict(a), tf.zeros([0], tf.float32))
   # Start from the leading (sample) dimension alone so that `s` is always
   # bound; previously a falsy `member_type.shape` left `s` undefined and the
   # `tf.zeros` call below raised UnboundLocalError.
   s = [0]
   if member_type.shape:
     s = s + member_type.shape.dims
   return _Samples(tf.zeros(s, member_type.dtype), tf.zeros([0], tf.float32))
Exemplo n.º 6
0
 def test_single_unnamed(self):
     """Checks an `AnonymousTuple` holding exactly one unnamed element.

     Covers length, positional access, iteration, the absence of named
     attributes, equality, `to_elements`, `repr`/`str`, and that `to_odict`
     rejects the unnamed element.
     """
     elements = [(None, 10)]
     x = anonymous_tuple.AnonymousTuple(elements)

     # Length, positional access and iteration.
     self.assertLen(x, 1)
     with self.assertRaises(IndexError):
         _ = x[1]
     self.assertEqual(x[0], 10)
     self.assertEqual(list(iter(x)), [10])

     # No names means no attributes.
     self.assertEqual(dir(x), [])
     with self.assertRaises(AttributeError):
         _ = x.foo

     # Equality against empty, named, identical and longer tuples.
     empty = anonymous_tuple.AnonymousTuple([])
     self.assertNotEqual(x, empty)
     named = anonymous_tuple.AnonymousTuple([('foo', 10)])
     self.assertNotEqual(x, named)
     identical = anonymous_tuple.AnonymousTuple([(None, 10)])
     self.assertEqual(x, identical)
     longer = anonymous_tuple.AnonymousTuple([(None, 10), ('foo', 20)])
     self.assertNotEqual(x, longer)

     # Conversions and string forms.
     self.assertEqual(anonymous_tuple.to_elements(x), elements)
     self.assertEqual(repr(x), 'AnonymousTuple([(None, 10)])')
     self.assertEqual(str(x), '<10>')
     with self.assertRaisesRegex(ValueError, 'unnamed'):
         anonymous_tuple.to_odict(x)
Exemplo n.º 7
0
 def accumlator_type_fn():
   """Gets the empty initial `_Samples` value for the accumulators.

   Reads `member_type` from the enclosing scope.

   Returns:
     A `_Samples` whose first field holds zero tensors with a leading
     dimension of size 0 (an ordered dict of them when `member_type` is a
     struct type), and whose second field is an empty float32 tensor.
   """
   # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
   if member_type.is_struct():
     a = anonymous_tuple.map_structure(
         lambda v: tf.zeros([0] + v.shape.dims, v.dtype), member_type)
     return _Samples(
         anonymous_tuple.to_odict(a, True), tf.zeros([0], tf.float32))
   # Start from the leading (sample) dimension alone so that `s` is always
   # bound; previously a falsy `member_type.shape` left `s` undefined and the
   # `tf.zeros` call below raised UnboundLocalError.
   s = [0]
   if member_type.shape:
     s = s + member_type.shape.dims
   return _Samples(tf.zeros(s, member_type.dtype), tf.zeros([0], tf.float32))
Exemplo n.º 8
0
  def test_federated_sum_named_tuples(self):
    """Checks the FEDERATED_SUM intrinsic body on named-tuple client values.

    Verifies the type signature of the wrapping computation and the summed
    fields for one and for two client values.
    """
    bodies = intrinsic_bodies.get_intrinsic_bodies(
        context_stack_impl.context_stack)
    clients_type = computation_types.FederatedType(
        [('a', tf.int32), ('b', tf.float32)], placements.CLIENTS)

    @computations.federated_computation(clients_type)
    def foo(x):
      return bodies[intrinsic_defs.FEDERATED_SUM.uri](x)

    expected_signature = (
        '({<a=int32,b=float32>}@CLIENTS -> <a=int32,b=float32>@SERVER)')
    self.assertEqual(str(foo.type_signature), expected_signature)

    # A single client: the sum is that client's value.
    single_client_sum = anonymous_tuple.to_odict(foo([[1, 2.]]))
    self.assertDictEqual(single_client_sum, {'a': 1, 'b': 2.})

    # Two clients: fields are summed element-wise.
    two_client_sum = anonymous_tuple.to_odict(foo([[1, 2.], [3, 4.]]))
    self.assertDictEqual(two_client_sum, {'a': 4, 'b': 6.})
Exemplo n.º 9
0
 def test_empty(self):
   """Checks an `AnonymousTuple` constructed from no elements.

   Covers length, failing positional and attribute access, iteration,
   equality, `to_elements`, `to_odict`, and `repr`/`str`.
   """
   elements = []
   x = anonymous_tuple.AnonymousTuple(elements)

   # Explicitly test the implementation of __len__() here, so use assertLen()
   # instead of assertEmpty().
   self.assertLen(x, 0)  # pylint: disable=g-generic-assert
   with self.assertRaises(IndexError):
     _ = x[0]
   self.assertEqual(list(iter(x)), [])

   # No names means no attributes.
   self.assertEqual(dir(x), [])
   with self.assertRaises(AttributeError):
     _ = x.foo

   # Equality against another empty tuple and a non-empty one.
   another_empty = anonymous_tuple.AnonymousTuple([])
   self.assertEqual(x, another_empty)
   non_empty = anonymous_tuple.AnonymousTuple([('foo', 10)])
   self.assertNotEqual(x, non_empty)

   # Conversions and string forms.
   self.assertEqual(anonymous_tuple.to_elements(x), elements)
   self.assertEqual(anonymous_tuple.to_odict(x), collections.OrderedDict())
   self.assertEqual(repr(x), 'AnonymousTuple([])')
   self.assertEqual(str(x), '<>')
Exemplo n.º 10
0
    def server_update_model_tf(server_state, model_delta):
      """Converts args to correct python types and calls server_update_model.

      Args:
        server_state: An `anonymous_tuple.AnonymousTuple` exposing `model` and
          `optimizer_state`.
        model_delta: An `anonymous_tuple.AnonymousTuple`; converted to an
          ordered dict before being forwarded.

      Returns:
        Whatever `server_update_model` returns for the converted arguments,
        with `model_fn` and `server_optimizer_fn` taken from the enclosing
        scope.
      """
      # We need to convert TFF types to the types server_update_model expects.
      # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
      # pretty, fold this into anonymous_tuple module or get working with
      # tf.contrib.framework.nest.
      py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
      delta_odict = anonymous_tuple.to_odict(model_delta)

      py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
      weights = model_utils.ModelWeights.from_tff_value(server_state.model)
      optimizer_state = list(server_state.optimizer_state)
      typed_state = ServerState(model=weights, optimizer_state=optimizer_state)

      return server_update_model(
          typed_state,
          delta_odict,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)
Exemplo n.º 11
0
 def test_single_named(self):
   """Checks an `AnonymousTuple` holding exactly one named element.

   Covers length, positional and attribute access, iteration, `dir()`,
   equality, `to_elements`, `repr`/`str`, and conversion with `to_odict`.
   """
   elements = [('foo', 20)]
   x = anonymous_tuple.AnonymousTuple(elements)

   # Length, positional access and iteration.
   self.assertLen(x, 1)
   self.assertEqual(x[0], 20)
   with self.assertRaises(IndexError):
     _ = x[1]
   self.assertEqual(list(iter(x)), [20])

   # The single name is exposed as an attribute; other names are not.
   self.assertEqual(dir(x), ['foo'])
   self.assertEqual(x.foo, 20)
   with self.assertRaises(AttributeError):
     _ = x.bar

   # Equality requires the same name, value and length.
   empty = anonymous_tuple.AnonymousTuple([])
   self.assertNotEqual(x, empty)
   other_value = anonymous_tuple.AnonymousTuple([('foo', 10)])
   self.assertNotEqual(x, other_value)
   unnamed = anonymous_tuple.AnonymousTuple([(None, 20)])
   self.assertNotEqual(x, unnamed)
   identical = anonymous_tuple.AnonymousTuple([('foo', 20)])
   self.assertEqual(x, identical)
   longer = anonymous_tuple.AnonymousTuple([('foo', 20), ('bar', 30)])
   self.assertNotEqual(x, longer)

   # Conversions and string forms.
   self.assertEqual(anonymous_tuple.to_elements(x), elements)
   self.assertEqual(repr(x), 'AnonymousTuple([(\'foo\', 20)])')
   self.assertEqual(str(x), '<foo=20>')
   expected_odict = collections.OrderedDict(elements)
   self.assertEqual(anonymous_tuple.to_odict(x), expected_odict)
Exemplo n.º 12
0
 def from_tff_value(cls, anon_tuple):
   """Builds an instance of `cls` from an `AnonymousTuple` of weights.

   Args:
     anon_tuple: An `anonymous_tuple.AnonymousTuple` exposing `trainable` and
       `non_trainable`; its type is checked with `py_typecheck.check_type`.

   Returns:
     `cls` called with the ordered-dict forms of `anon_tuple.trainable` and
     `anon_tuple.non_trainable`, in that order.
   """
   py_typecheck.check_type(anon_tuple, anonymous_tuple.AnonymousTuple)
   trainable = anonymous_tuple.to_odict(anon_tuple.trainable)
   non_trainable = anonymous_tuple.to_odict(anon_tuple.non_trainable)
   return cls(trainable, non_trainable)