# Example #1
# 0
 def test_multiple_named_and_unnamed(self):
   """Checks a `Struct` that mixes one unnamed and two named elements."""
   elements = [(None, 10), ('foo', 20), ('bar', 30)]
   struct = structure.Struct(elements)

   # Positional access and iteration.
   self.assertLen(struct, 3)
   for index, expected in enumerate([10, 20, 30]):
     self.assertEqual(struct[index], expected)
   with self.assertRaises(IndexError):
     _ = struct[3]
   self.assertEqual(list(iter(struct)), [10, 20, 30])

   # Named access: only the named elements show up as attributes.
   self.assertEqual(dir(struct), ['bar', 'foo'])
   self.assertEqual(structure.name_list(struct), ['foo', 'bar'])
   self.assertEqual(struct.foo, 20)
   self.assertEqual(struct.bar, 30)
   with self.assertRaises(AttributeError):
     _ = struct.baz

   # Equality compares names as well as values.
   self.assertEqual(
       struct, structure.Struct([(None, 10), ('foo', 20), ('bar', 30)]))
   renamed = structure.Struct([('foo', 10), ('bar', 20), (None, 30)])
   self.assertNotEqual(struct, renamed)

   # Round-tripping and string forms.
   self.assertEqual(structure.to_elements(struct), elements)
   self.assertEqual(
       repr(struct), 'Struct([(None, 10), (\'foo\', 20), (\'bar\', 30)])')
   self.assertEqual(str(struct), '<10,foo=20,bar=30>')

   # Mixed naming cannot be converted to an ordered dict.
   with self.assertRaisesRegex(ValueError, 'unnamed'):
     structure.to_odict(struct)
   with self.assertRaisesRegex(ValueError, 'named and unnamed'):
     structure.to_odict_or_tuple(struct)
def _get_accumulator_type(member_type):
  """Constructs a `tff.Type` for the accumulator in sample aggregation.

  Every tensor in `member_type` is given an extra leading dimension of unknown
  size (`None`).

  Args:
    member_type: A `tff.Type` representing the member components of the
      federated type.

  Returns:
    The `tff.StructType` associated with the accumulator. The struct contains
    two parts, `accumulators` and `rands`, that are parallel lists (i.e. the
    i-th index in one corresponds to the i-th index in the other). These two
    lists are used to sample from the accumulators with equal probability.
    `accumulators` mirrors `member_type` with the extra leading `None`
    dimension; `rands` is a `tf.float32` vector of unknown length.
  """
  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    batched = structure.map_structure(
        lambda leaf: computation_types.TensorType(
            leaf.dtype, [None] + leaf.shape.dims),
        member_type)
    accumulators_type = computation_types.StructType(
        structure.to_odict(batched, True))
  else:
    accumulators_type = computation_types.TensorType(
        member_type.dtype, shape=[None] + member_type.shape.dims)

  rands_type = computation_types.TensorType(tf.float32, shape=[None])
  return computation_types.StructType(
      collections.OrderedDict([
          ('accumulators', accumulators_type),
          ('rands', rands_type),
      ]))
# Example #3
# 0
 def test_single_unnamed(self):
   """Checks a `Struct` holding exactly one unnamed element."""
   elements = [(None, 10)]
   struct = structure.Struct(elements)

   # Positional access and iteration.
   self.assertLen(struct, 1)
   with self.assertRaises(IndexError):
     _ = struct[1]
   self.assertEqual(struct[0], 10)
   self.assertEqual(list(iter(struct)), [10])

   # An unnamed element exposes no attributes.
   self.assertEqual(dir(struct), [])
   with self.assertRaises(AttributeError):
     _ = struct.foo

   # Equality: differs from empty and from a named element with equal value.
   for unequal_elements in ([], [('foo', 10)]):
     self.assertNotEqual(struct, structure.Struct(unequal_elements))
   self.assertEqual(struct, structure.Struct([(None, 10)]))
   self.assertNotEqual(struct, structure.Struct([(None, 10), ('foo', 20)]))

   # Round-tripping and string forms.
   self.assertEqual(structure.to_elements(struct), elements)
   self.assertEqual(repr(struct), 'Struct([(None, 10)])')
   self.assertEqual(str(struct), '<10>')

   # An unnamed element cannot be converted to an ordered dict.
   with self.assertRaisesRegex(ValueError, 'unnamed'):
     structure.to_odict(struct)
 def accumlator_type_fn():
   """Gets the (empty) initial value for the accumulators.

   Builds zero-length tensors shaped like `member_type` with an extra leading
   dimension of size 0, plus an empty `tf.float32` vector for the random
   values, and wraps both in `_Samples`.

   Returns:
     A `_Samples` value: `(accumulators, rands)`. For a struct `member_type`
     the accumulators are an ordered dict (via `structure.to_odict(..., True)`)
     of empty tensors; otherwise a single empty tensor.

   Raises:
     ValueError: if `member_type` is not a struct and its `shape` is falsy,
       so no leading-dimension shape can be built from it.
   """
   # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
   if member_type.is_struct():
     a = structure.map_structure(
         lambda v: tf.zeros([0] + v.shape.dims, v.dtype), member_type)
     return _Samples(structure.to_odict(a, True), tf.zeros([0], tf.float32))
   if not member_type.shape:
     # Previously this case fell through to the return below with `s` never
     # assigned, surfacing as an opaque UnboundLocalError. Fail explicitly.
     # NOTE(review): for tf.TensorShape a falsy shape presumably means unknown
     # rank — confirm.
     raise ValueError(
         'Cannot build an empty accumulator for member type {} because its '
         'shape is not known.'.format(member_type))
   s = [0] + member_type.shape.dims
   return _Samples(tf.zeros(s, member_type.dtype), tf.zeros([0], tf.float32))
  def test_federated_sum_named_tuples(self):
    """Checks the FEDERATED_SUM intrinsic body on a named-tuple member type."""
    bodies = intrinsic_bodies.get_intrinsic_bodies(
        context_stack_impl.context_stack)
    client_type = computation_types.FederatedType(
        [('a', tf.int32), ('b', tf.float32)], placement_literals.CLIENTS)

    @computations.federated_computation(client_type)
    def foo(x):
      return bodies[intrinsic_defs.FEDERATED_SUM.uri](x)

    expected_signature = (
        '({<a=int32,b=float32>}@CLIENTS -> <a=int32,b=float32>@SERVER)')
    self.assertEqual(str(foo.type_signature), expected_signature)

    # A single client: the result equals that client's value.
    one_client = structure.to_odict(foo([[1, 2.0]]))
    self.assertDictEqual(one_client, {'a': 1, 'b': 2.0})

    # Two clients: each named field is summed across clients.
    two_clients = structure.to_odict(foo([[1, 2.0], [3, 4.0]]))
    self.assertDictEqual(two_clients, {'a': 4, 'b': 6.0})
# Example #6
# 0
 def test_empty(self):
   """Checks a `Struct` with no elements."""
   elements = []
   struct = structure.Struct(elements)

   # This explicitly exercises __len__(), so use assertLen() rather than
   # assertEmpty().
   self.assertLen(struct, 0)  # pylint: disable=g-generic-assert
   with self.assertRaises(IndexError):
     _ = struct[0]
   self.assertEqual(list(iter(struct)), [])

   # No elements means no attributes.
   self.assertEqual(dir(struct), [])
   with self.assertRaises(AttributeError):
     _ = struct.foo

   # Equality.
   self.assertEqual(struct, structure.Struct([]))
   self.assertNotEqual(struct, structure.Struct([('foo', 10)]))

   # Conversions and string forms.
   self.assertEqual(structure.to_elements(struct), elements)
   self.assertEqual(structure.to_odict(struct), collections.OrderedDict())
   self.assertEqual(repr(struct), 'Struct([])')
   self.assertEqual(str(struct), '<>')
# Example #7
# 0
 def test_single_named(self):
   """Checks a `Struct` holding exactly one named element."""
   elements = [('foo', 20)]
   struct = structure.Struct(elements)

   # Positional access and iteration.
   self.assertLen(struct, 1)
   self.assertEqual(struct[0], 20)
   with self.assertRaises(IndexError):
     _ = struct[1]
   self.assertEqual(list(iter(struct)), [20])

   # The name is exposed as an attribute; unknown names are not.
   self.assertEqual(dir(struct), ['foo'])
   self.assertEqual(struct.foo, 20)
   with self.assertRaises(AttributeError):
     _ = struct.bar

   # Equality is sensitive to length, value, and name.
   for unequal_elements in ([], [('foo', 10)], [(None, 20)]):
     self.assertNotEqual(struct, structure.Struct(unequal_elements))
   self.assertEqual(struct, structure.Struct([('foo', 20)]))
   self.assertNotEqual(struct, structure.Struct([('foo', 20), ('bar', 30)]))

   # Conversions and string forms.
   self.assertEqual(structure.to_elements(struct), elements)
   self.assertEqual(repr(struct), 'Struct([(\'foo\', 20)])')
   self.assertEqual(str(struct), '<foo=20>')
   self.assertEqual(
       structure.to_odict(struct), collections.OrderedDict(elements))