Exemplo n.º 1
0
    async def test_get_value_returns_value(self, coro, type_signature,
                                           expected_value):
        """Checks that awaiting `get_value()` yields the expected value.

        Args:
          coro: The coroutine wrapped by the `CoroValueReference` under test.
          type_signature: The type signature passed to the reference.
          expected_value: The value `get_value()` is expected to produce.
        """
        reference = native_platform.CoroValueReference(
            coro=coro,
            type_signature=type_signature,
        )

        self.assertEqual(await reference.get_value(), expected_value)
Exemplo n.º 2
0
    def test_init_does_not_raise_type_error_with_type_signature(
            self, coro, type_signature):
        """Verifies that building a `CoroValueReference` raises no TypeError.

        Only `TypeError` is intercepted and converted into an explicit test
        failure; any other exception propagates unchanged.

        Args:
          coro: The coroutine handed to the constructor.
          type_signature: The type signature handed to the constructor.
        """
        init_kwargs = {'coro': coro, 'type_signature': type_signature}
        try:
            native_platform.CoroValueReference(**init_kwargs)
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')
Exemplo n.º 3
0
class CreateStructureOfCoroReferencesTest(parameterized.TestCase,
                                          unittest.IsolatedAsyncioTestCase):
    """Tests for `native_platform._create_structure_of_coro_references`."""

    # Each case is (name, coro, type_signature, expected_value), where
    # `expected_value` is either a single `CoroValueReference` or a
    # `structure.Struct` of them mirroring `type_signature`.
    # pyformat: disable
    @parameterized.named_parameters(
        ('tensor', _coro(1), computation_types.TensorType(tf.int32),
         native_platform.CoroValueReference(
             _coro(1), computation_types.TensorType(tf.int32))),
        ('federated', _coro(1),
         computation_types.FederatedType(
             computation_types.TensorType(tf.int32), placements.SERVER),
         native_platform.CoroValueReference(
             _coro(1), computation_types.TensorType(tf.int32))),
        ('struct_unnamed', _coro((True, 1, 'a')),
         computation_types.StructType([
             (None, computation_types.TensorType(tf.bool)),
             (None, computation_types.TensorType(tf.int32)),
             (None, computation_types.TensorType(tf.string)),
         ]),
         structure.Struct([
             (None,
              native_platform.CoroValueReference(
                  _coro(True), computation_types.TensorType(tf.bool))),
             (None,
              native_platform.CoroValueReference(
                  _coro(1), computation_types.TensorType(tf.int32))),
             (None,
              native_platform.CoroValueReference(
                  _coro('a'), computation_types.TensorType(tf.string))),
         ])),
        ('struct_named',
         _coro(collections.OrderedDict([('a', True), ('b', 1), ('c', 'a')])),
         computation_types.StructType([
             ('a', computation_types.TensorType(tf.bool)),
             ('b', computation_types.TensorType(tf.int32)),
             ('c', computation_types.TensorType(tf.string)),
         ]),
         structure.Struct([
             ('a',
              native_platform.CoroValueReference(
                  _coro(True), computation_types.TensorType(tf.bool))),
             ('b',
              native_platform.CoroValueReference(
                  _coro(1), computation_types.TensorType(tf.int32))),
             ('c',
              native_platform.CoroValueReference(
                  _coro('a'), computation_types.TensorType(tf.string))),
         ])),
        ('struct_nested',
         _coro(
             collections.OrderedDict([
                 ('x', collections.OrderedDict([('a', True), ('b', 1)])),
                 ('y', collections.OrderedDict([('c', 'a')])),
             ])),
         computation_types.StructType([
             ('x',
              computation_types.StructType([
                  ('a', computation_types.TensorType(tf.bool)),
                  ('b', computation_types.TensorType(tf.int32)),
              ])),
             ('y',
              computation_types.StructType([
                  ('c', computation_types.TensorType(tf.string)),
              ])),
         ]),
         structure.Struct([
             ('x',
              structure.Struct([
                  ('a',
                   native_platform.CoroValueReference(
                       _coro(True), computation_types.TensorType(tf.bool))),
                  ('b',
                   native_platform.CoroValueReference(
                       _coro(1), computation_types.TensorType(tf.int32))),
              ])),
             ('y',
              structure.Struct([
                  ('c',
                   native_platform.CoroValueReference(
                       _coro('a'), computation_types.TensorType(tf.string))),
              ])),
         ])),
    )
    # pyformat: enable
    async def test_returns_value(self, coro, type_signature, expected_value):
        """Checks the returned references resolve to the expected values.

        When both the actual and expected results are `structure.Struct`s,
        their structures are asserted to match and the flattened references
        are materialized and compared pairwise. Otherwise both are treated as
        single references and their materialized values are compared.

        Args:
          coro: The coroutine passed to the function under test.
          type_signature: The type signature passed to the function under test.
          expected_value: A reference, or `structure.Struct` of references,
            that resolves to the expected materialized value(s).
        """
        actual_value = native_platform._create_structure_of_coro_references(
            coro=coro, type_signature=type_signature)

        if (isinstance(actual_value, structure.Struct)
                and isinstance(expected_value, structure.Struct)):
            # The comparison result has to be asserted: calling
            # `is_same_structure` on its own only returns a bool and would
            # let a structural mismatch go unnoticed.
            self.assertTrue(
                structure.is_same_structure(actual_value, expected_value))
            actual_value = structure.flatten(actual_value)
            expected_value = structure.flatten(expected_value)
            # `zip` stops at the shorter input, so make sure no element is
            # silently skipped by the pairwise comparison below.
            self.assertEqual(len(actual_value), len(expected_value))
            for a, b in zip(actual_value, expected_value):
                a = await a.get_value()
                b = await b.get_value()
                self.assertEqual(a, b)
        else:
            actual_value = await actual_value.get_value()
            expected_value = await expected_value.get_value()
            self.assertEqual(actual_value, expected_value)

    @parameterized.named_parameters(
        ('none', None),
        ('bool', True),
        ('int', 1),
        ('str', 'a'),
        ('list', []),
    )
    def test_raises_type_error_with_type_signature(self, type_signature):
        """Checks that each parameterized `type_signature` raises TypeError."""
        coro = _coro(1)

        with self.assertRaises(TypeError):
            native_platform._create_structure_of_coro_references(
                coro=coro, type_signature=type_signature)

    # pyformat: disable
    @parameterized.named_parameters(
        ('federated',
         computation_types.FederatedType(
             computation_types.TensorType(tf.int32), placements.CLIENTS)),
        ('function',
         computation_types.FunctionType(computation_types.TensorType(
             tf.int32), computation_types.TensorType(tf.int32))),
        ('placement', computation_types.PlacementType()),
    )
    # pyformat: enable
    def test_raises_not_implemented_error_with_type_signature(
            self, type_signature):
        """Checks the listed type signatures raise NotImplementedError."""
        coro = _coro(1)

        with self.assertRaises(NotImplementedError):
            native_platform._create_structure_of_coro_references(
                coro=coro, type_signature=type_signature)

    async def test_returned_structure_materialized_sequentially(self):
        """Awaits each returned reference one at a time, in order."""
        coro = _coro((True, 1, 'a'))
        type_signature = computation_types.StructType([
            (None, computation_types.TensorType(tf.bool)),
            (None, computation_types.TensorType(tf.int32)),
            (None, computation_types.TensorType(tf.string)),
        ])

        result = native_platform._create_structure_of_coro_references(
            coro=coro, type_signature=type_signature)

        actual_values = []
        for value in result:
            actual_value = await value.get_value()
            actual_values.append(actual_value)
        expected_values = [True, 1, 'a']
        self.assertEqual(actual_values, expected_values)

    async def test_returned_structure_materialized_concurrently(self):
        """Awaits all returned references together via `asyncio.gather`."""
        coro = _coro((True, 1, 'a'))
        type_signature = computation_types.StructType([
            (None, computation_types.TensorType(tf.bool)),
            (None, computation_types.TensorType(tf.int32)),
            (None, computation_types.TensorType(tf.string)),
        ])

        result = native_platform._create_structure_of_coro_references(
            coro=coro, type_signature=type_signature)

        actual_values = await asyncio.gather(*[v.get_value() for v in result])
        expected_values = [True, 1, 'a']
        self.assertEqual(actual_values, expected_values)
Exemplo n.º 4
0
    def test_init_raises_type_error_with_type_signature(self, type_signature):
        """Expects `CoroValueReference` to reject the given `type_signature`.

        Args:
          type_signature: The parameterized value passed as the type
            signature; construction is expected to raise `TypeError`.
        """
        init_kwargs = {'coro': _coro(1), 'type_signature': type_signature}

        with self.assertRaises(TypeError):
            native_platform.CoroValueReference(**init_kwargs)
Exemplo n.º 5
0
    def test_init_raises_type_error_with_coro(self, coro):
        """Expects `CoroValueReference` to reject the given `coro` argument.

        Args:
          coro: The parameterized value passed as the coroutine; construction
            is expected to raise `TypeError`.
        """
        # Built outside the `assertRaises` block so that only the constructor
        # call itself can satisfy the expected TypeError.
        int32_type = computation_types.TensorType(tf.int32)

        with self.assertRaises(TypeError):
            native_platform.CoroValueReference(type_signature=int32_type,
                                               coro=coro)
Exemplo n.º 6
0
class MaterializeStructureOfValueReferencesTest(
        parameterized.TestCase, unittest.IsolatedAsyncioTestCase):
    """Tests `native_platform._materialize_structure_of_value_references`."""

    # Every case is (name, value, type_signature, expected_value): `value`
    # holds `CoroValueReference`s and `expected_value` is what materializing
    # it should produce.
    # pyformat: disable
    @parameterized.named_parameters(
        (
            'tensor',
            native_platform.CoroValueReference(
                _coro(1),
                computation_types.TensorType(tf.int32),
            ),
            computation_types.TensorType(tf.int32),
            1,
        ),
        (
            'federated',
            native_platform.CoroValueReference(
                _coro(1),
                computation_types.TensorType(tf.int32),
            ),
            computation_types.FederatedType(
                computation_types.TensorType(tf.int32),
                placements.SERVER,
            ),
            1,
        ),
        (
            'struct_unnamed',
            (
                native_platform.CoroValueReference(
                    _coro(True),
                    computation_types.TensorType(tf.bool),
                ),
                native_platform.CoroValueReference(
                    _coro(1),
                    computation_types.TensorType(tf.int32),
                ),
                native_platform.CoroValueReference(
                    _coro('a'),
                    computation_types.TensorType(tf.string),
                ),
            ),
            computation_types.StructType([
                (None, computation_types.TensorType(tf.bool)),
                (None, computation_types.TensorType(tf.int32)),
                (None, computation_types.TensorType(tf.string)),
            ]),
            structure.Struct([(None, True), (None, 1), (None, 'a')]),
        ),
        (
            'struct_named',
            collections.OrderedDict([
                (
                    'a',
                    native_platform.CoroValueReference(
                        _coro(True),
                        computation_types.TensorType(tf.bool),
                    ),
                ),
                (
                    'b',
                    native_platform.CoroValueReference(
                        _coro(1),
                        computation_types.TensorType(tf.int32),
                    ),
                ),
                (
                    'c',
                    native_platform.CoroValueReference(
                        _coro('a'),
                        computation_types.TensorType(tf.string),
                    ),
                ),
            ]),
            computation_types.StructType([
                ('a', computation_types.TensorType(tf.bool)),
                ('b', computation_types.TensorType(tf.int32)),
                ('c', computation_types.TensorType(tf.string)),
            ]),
            structure.Struct([('a', True), ('b', 1), ('c', 'a')]),
        ),
        (
            'struct_nested',
            collections.OrderedDict([
                (
                    'x',
                    collections.OrderedDict([
                        (
                            'a',
                            native_platform.CoroValueReference(
                                _coro(True),
                                computation_types.TensorType(tf.bool),
                            ),
                        ),
                        (
                            'b',
                            native_platform.CoroValueReference(
                                _coro(1),
                                computation_types.TensorType(tf.int32),
                            ),
                        ),
                    ]),
                ),
                (
                    'y',
                    collections.OrderedDict([
                        (
                            'c',
                            native_platform.CoroValueReference(
                                _coro('a'),
                                computation_types.TensorType(tf.string),
                            ),
                        ),
                    ]),
                ),
            ]),
            computation_types.StructType([
                (
                    'x',
                    computation_types.StructType([
                        ('a', computation_types.TensorType(tf.bool)),
                        ('b', computation_types.TensorType(tf.int32)),
                    ]),
                ),
                (
                    'y',
                    computation_types.StructType([
                        ('c', computation_types.TensorType(tf.string)),
                    ]),
                ),
            ]),
            structure.Struct([
                ('x', structure.Struct([('a', True), ('b', 1)])),
                ('y', structure.Struct([('c', 'a')])),
            ]),
        ),
    )
    # pyformat: enable
    async def test_returns_value(self, value, type_signature, expected_value):
        """Materializes `value` and compares it with `expected_value`.

        Args:
          value: A reference, or container of references, to materialize.
          type_signature: The type signature passed alongside `value`.
          expected_value: The expected materialized result.
        """
        materialize = (
            native_platform._materialize_structure_of_value_references)
        materialized = await materialize(
            value=value,
            type_signature=type_signature,
        )

        self.assertEqual(materialized, expected_value)

    @parameterized.named_parameters(
        ('none', None),
        ('bool', True),
        ('int', 1),
        ('str', 'a'),
        ('list', []),
    )
    async def test_raises_type_error_with_type_signature(self, type_signature):
        """Checks that each parameterized `type_signature` raises TypeError."""
        materialize = (
            native_platform._materialize_structure_of_value_references)

        with self.assertRaises(TypeError):
            await materialize(type_signature=type_signature, value=1)