예제 #1
0
    def test_nested_py_containers(self):
        """Nested structs are converted to the Python containers they declare."""
        inner_struct = anonymous_tuple.AnonymousTuple([(None, 4), (None, 5)])
        dict_value = anonymous_tuple.AnonymousTuple([('a', 3),
                                                     ('b', inner_struct)])
        anon_tuple = anonymous_tuple.AnonymousTuple([
            (None, 1),
            (None, 2.0),
            ('dict_key', dict_value),
        ])

        pair_type = computation_types.StructWithPythonType(
            [tf.int32, tf.int32], tuple)
        dict_subtype = computation_types.StructWithPythonType(
            [('a', tf.int32), ('b', pair_type)], dict)
        type_spec = computation_types.StructType([
            (None, tf.int32),
            (None, tf.float32),
            ('dict_key', dict_subtype),
        ])

        # The outer struct has no Python container, so it stays an
        # AnonymousTuple; only the typed inner members are converted.
        expected_nested_structure = anonymous_tuple.AnonymousTuple([
            (None, 1),
            (None, 2.0),
            ('dict_key', {'a': 3, 'b': (4, 5)}),
        ])

        converted = type_conversions.type_to_py_container(anon_tuple, type_spec)
        self.assertEqual(converted, expected_nested_structure)
예제 #2
0
    def test_anon_tuple_with_names_to_container_with_names(self):
        """A fully named struct converts into every name-supporting container."""
        anon_tuple = anonymous_tuple.AnonymousTuple([('a', 1), ('b', 2.0)])
        types = [('a', tf.int32), ('b', tf.float32)]

        def convert(container):
            # Convert `anon_tuple` using the shared element types in `container`.
            return type_conversions.type_to_py_container(
                anon_tuple,
                computation_types.StructWithPythonType(types, container))

        self.assertDictEqual(convert(dict), {'a': 1, 'b': 2.0})
        self.assertSequenceEqual(
            convert(collections.OrderedDict),
            collections.OrderedDict([('a', 1), ('b', 2.0)]))

        named_tuple_cls = collections.namedtuple('TestNamedTuple', ['a', 'b'])
        self.assertSequenceEqual(
            convert(named_tuple_cls), named_tuple_cls(a=1, b=2.0))

        @attr.s
        class TestFoo(object):
            a = attr.ib()
            b = attr.ib()

        self.assertEqual(convert(TestFoo), TestFoo(a=1, b=2.0))
예제 #3
0
 def test_with_two_level_tuple(self):
     """Nested OrderedDict types map to nested dicts of dtypes and shapes."""
     inner_type = computation_types.StructWithPythonType([
         ('c', computation_types.TensorType(tf.float32)),
         ('d', computation_types.TensorType(tf.int32, [20])),
     ], collections.OrderedDict)
     type_signature = computation_types.StructWithPythonType([
         ('a', tf.bool),
         ('b', inner_type),
         ('e', computation_types.StructType([])),
     ], collections.OrderedDict)

     dtypes, shapes = type_conversions.type_to_tf_dtypes_and_shapes(
         type_signature)

     # An empty struct is expected to map to an empty tuple in both outputs.
     expected_dtypes = {
         'a': tf.bool,
         'b': {'c': tf.float32, 'd': tf.int32},
         'e': (),
     }
     test.assert_nested_struct_eq(dtypes, expected_dtypes)

     expected_shapes = {
         'a': tf.TensorShape([]),
         'b': {'c': tf.TensorShape([]), 'd': tf.TensorShape([20])},
         'e': (),
     }
     test.assert_nested_struct_eq(shapes, expected_shapes)
예제 #4
0
    def test_anon_tuple_without_names_to_container_with_names_fails(self):
        """Unnamed struct values cannot populate name-keyed containers."""
        anon_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0)])
        types = [('a', tf.int32), ('b', tf.float32)]

        def assert_conversion_fails(container):
            # Every named container must reject the unnamed value the same way.
            with self.assertRaisesRegex(ValueError,
                                        'value.*with unnamed elements'):
                type_conversions.type_to_py_container(
                    anon_tuple,
                    computation_types.StructWithPythonType(types, container))

        assert_conversion_fails(dict)
        assert_conversion_fails(collections.OrderedDict)

        named_tuple_cls = collections.namedtuple('TestNamedTuple', ['a', 'b'])
        assert_conversion_fails(named_tuple_cls)

        @attr.s
        class TestFoo(object):
            a = attr.ib()
            b = attr.ib()

        assert_conversion_fails(TestFoo)
예제 #5
0
class IsSumCompatibleTest(parameterized.TestCase):
    """Tests for `type_analysis.check_is_sum_compatible`."""

    @parameterized.named_parameters([
        ('tensor_type', computation_types.TensorType(tf.int32)),
        ('tuple_type_int', computation_types.StructType([tf.int32, tf.int32])),
        ('tuple_type_float',
         computation_types.StructType(
             [tf.complex128, tf.float32, tf.float64])),
        ('federated_type',
         computation_types.FederatedType(tf.int32, placements.CLIENTS)),
    ])
    def test_positive_examples(self, type_spec):
        # Succeeds as long as the check raises nothing.
        type_analysis.check_is_sum_compatible(type_spec)

    @parameterized.named_parameters([
        ('tensor_type_bool', computation_types.TensorType(tf.bool)),
        ('tensor_type_string', computation_types.TensorType(tf.string)),
        ('partially_defined_shape',
         computation_types.TensorType(tf.int32, shape=[None])),
        ('tuple_type', computation_types.StructType([tf.int32, tf.bool])),
        ('sequence_type', computation_types.SequenceType(tf.int32)),
        ('placement_type', computation_types.PlacementType()),
        ('function_type', computation_types.FunctionType(tf.int32, tf.int32)),
        ('abstract_type', computation_types.AbstractType('T')),
        ('ragged_tensor',
         computation_types.StructWithPythonType([], tf.RaggedTensor)),
        ('sparse_tensor',
         computation_types.StructWithPythonType([], tf.SparseTensor)),
    ])
    def test_negative_examples(self, type_spec):
        expected_error = type_analysis.SumIncompatibleError
        with self.assertRaises(expected_error):
            type_analysis.check_is_sum_compatible(type_spec)
예제 #6
0
class FnToBuildingBlockTest(parameterized.TestCase):
    """Tests for serializing Python functions with `_federated_computation_serializer`."""

    # pyformat: disable
    @parameterized.named_parameters(
        ('nested_fn_same',
         lambda f, x: f(f(x)),
         computation_types.StructType(
             [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]),
         '(FEDERATED_foo -> (let fc_FEDERATED_symbol_0=FEDERATED_foo.f(FEDERATED_foo.x),fc_FEDERATED_symbol_1=FEDERATED_foo.f(fc_FEDERATED_symbol_0) in fc_FEDERATED_symbol_1))'),
        ('nested_fn_different',
         lambda f, g, x: f(g(x)),
         computation_types.StructType(
             [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('g', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]),
         '(FEDERATED_foo -> (let fc_FEDERATED_symbol_0=FEDERATED_foo.g(FEDERATED_foo.x),fc_FEDERATED_symbol_1=FEDERATED_foo.f(fc_FEDERATED_symbol_0) in fc_FEDERATED_symbol_1))'),
        ('selection',
         lambda x: (x[1], x[0]),
         computation_types.StructType([tf.int32, tf.int32]),
         '(FEDERATED_foo -> <FEDERATED_foo[1],FEDERATED_foo[0]>)'),
        ('constant',
         lambda: 'stuff',
         None,
         '( -> (let fc_FEDERATED_symbol_0=comp#'),
    )
    # pyformat: enable
    def test_returns_result(self, fn, parameter_type, fn_str):
        """The serialized computation's string form starts with `fn_str`."""
        # A parameterless function is serialized without a parameter name.
        parameter_name = 'foo' if parameter_type is not None else None
        result, _ = _federated_computation_serializer(fn, parameter_name,
                                                      parameter_type)
        # Only a prefix is compared; presumably the text after e.g. 'comp#'
        # is not stable across runs -- confirm before tightening.
        self.assertStartsWith(str(result), fn_str)

    # pyformat: disable
    @parameterized.named_parameters(
        ('tuple',
         lambda x: (x[1], x[0]),
         computation_types.StructType([tf.int32, tf.float32]),
         computation_types.StructWithPythonType([(None, tf.float32),
                                                 (None, tf.int32)], tuple)),
        ('list',
         lambda x: [x[1], x[0]],
         computation_types.StructType([tf.int32, tf.float32]),
         computation_types.StructWithPythonType([(None, tf.float32),
                                                 (None, tf.int32)], list)),
        ('odict',
         lambda x: collections.OrderedDict([('A', x[1]), ('B', x[0])]),
         computation_types.StructType([tf.int32, tf.float32]),
         computation_types.StructWithPythonType([('A', tf.float32),
                                                 ('B', tf.int32)],
                                                collections.OrderedDict)),
        ('namedtuple',
         lambda x: TestNamedTuple(x=x[1], y=x[0]),
         computation_types.StructType([tf.int32, tf.float32]),
         computation_types.StructWithPythonType([('x', tf.float32),
                                                 ('y', tf.int32)],
                                                TestNamedTuple)),
    )
    # pyformat: enable
    def test_returns_result_with_py_container(self, fn, parameter_type,
                                              expected_result_type):
        """The result type records the Python container `fn` returned."""
        _, type_signature = _federated_computation_serializer(
            fn, 'foo', parameter_type)
        # The result must match in Python class, in container, and in value.
        self.assertIs(type(type_signature.result), type(expected_result_type))
        self.assertIs(type_signature.result.python_container,
                      expected_result_type.python_container)
        self.assertEqual(type_signature.result, expected_result_type)
예제 #7
0
def _parameter_type(
    parameters, parameter_types: Tuple[computation_types.Type, ...]
) -> Optional[computation_types.Type]:
    """Bundle any user-provided parameter types into a single argument type.

    Args:
      parameters: A sized collection of the function's parameters. Only each
        element's `.name` is read here (presumably `inspect.Parameter`
        objects -- confirm at the call site).
      parameter_types: The types the user supplied for those parameters.

    Returns:
      `None` if there are neither parameters nor types, or if the only type is
      `None` and the function takes no parameters. Otherwise, a single type
      describing the whole argument.

    Raises:
      TypeError: If the supplied types cannot be matched with `parameters`.
    """
    parameter_names = [parameter.name for parameter in parameters]
    if not parameter_types and not parameters:
        return None
    if len(parameter_types) == 1:
        parameter_type = parameter_types[0]
        if parameter_type is None and not parameters:
            return None
        if len(parameters) == 1:
            return parameter_type
        # A single type for zero or several parameters: it must be a struct
        # with exactly one element per parameter. `None` is rejected explicitly
        # so callers get a `TypeError` rather than an `AttributeError`.
        if (parameter_type is None or not parameter_type.is_struct() or
                len(parameter_type) != len(parameters)):
            raise TypeError(
                f'Function with {len(parameters)} parameters must have a parameter '
                f'type with the same number of parameters. Found parameter type '
                f'{parameter_type}.')
        # `structure.name_list` appears to list only the named fields, so a
        # length mismatch below means named and unnamed fields are mixed.
        name_list_from_types = structure.name_list(parameter_type)
        if name_list_from_types:
            if len(name_list_from_types) != len(parameter_type):
                raise TypeError(
                    'Types with both named and unnamed fields cannot be unpacked into '
                    f'argument lists. Found parameter type {parameter_type}.')
            if set(name_list_from_types) != set(parameter_names):
                raise TypeError(
                    'Function argument names must match field names of parameter type. '
                    f'Found argument names {parameter_names}, which do not match '
                    f'{name_list_from_types}, the top-level fields of the parameter '
                    f'type {parameter_type}.')
            # Every field is named and the names match the function's
            # parameters (in any order), so the type is used unchanged.
            return parameter_type
        # No field is named: attach the function's parameter names
        # positionally.
        element_types = [
            element_type
            for _, element_type in structure.to_elements(parameter_type)
        ]
        return computation_types.StructWithPythonType(
            list(zip(parameter_names, element_types)), collections.OrderedDict)
    if len(parameters) == 1:
        # Zero or several types but one parameter: bundle the types together
        # into one type for that parameter.
        return computation_types.to_type(parameter_types)
    if len(parameters) != len(parameter_types):
        raise TypeError(
            f'Function with {len(parameters)} parameters is '
            f'incompatible with provided argument types {parameter_types}.')
    # The function has `n` parameters and `n` parameter types.
    # Zip them up into a structure using the names from the function as keys.
    return computation_types.StructWithPythonType(
        list(zip(parameter_names, parameter_types)), collections.OrderedDict)
예제 #8
0
 def test_anon_tuple_with_names_to_container_without_names_fails(self):
   """Named struct values cannot fill containers that lack field names."""
   types = [tf.int32, tf.float32]

   # A mix of named and unnamed elements is rejected outright.
   partially_named = structure.Struct([(None, 1), ('a', 2.0)])
   with self.assertRaisesRegex(ValueError,
                               'contains a mix of named and unnamed elements'):
     type_conversions.type_to_py_container(
         partially_named, computation_types.StructWithPythonType(types, tuple))

   # Fully named elements cannot be placed into a `list`.
   fully_named = structure.Struct([('a', 1), ('b', 2.0)])
   with self.assertRaisesRegex(ValueError, 'which does not support names'):
     type_conversions.type_to_py_container(
         fully_named, computation_types.StructWithPythonType(types, list))
예제 #9
0
 def test_anon_tuple_without_names_to_container_without_names(self):
   """An unnamed struct converts into both `list` and `tuple` containers."""
   anon_tuple = structure.Struct([(None, 1), (None, 2.0)])
   types = [tf.int32, tf.float32]

   as_list = type_conversions.type_to_py_container(
       anon_tuple, computation_types.StructWithPythonType(types, list))
   self.assertSequenceEqual(as_list, [1, 2.0])

   as_tuple = type_conversions.type_to_py_container(
       anon_tuple, computation_types.StructWithPythonType(types, tuple))
   self.assertSequenceEqual(as_tuple, (1, 2.0))
예제 #10
0
 def test_with_ragged_tensor(self):
     """A RaggedTensor is inferred as a struct with a RaggedTensor container."""
     ragged = tf.RaggedTensor.from_row_splits([0, 0, 0, 0], [0, 1, 4])
     inferred_type = type_conversions.infer_type(ragged)

     row_splits_type = computation_types.StructWithPythonType(
         [(None, computation_types.TensorType(tf.int64, [3]))], tuple)
     expected_type = computation_types.StructWithPythonType([
         ('flat_values', computation_types.TensorType(tf.int32, [4])),
         ('nested_row_splits', row_splits_type),
     ], tf.RaggedTensor)
     self.assert_types_identical(inferred_type, expected_type)
예제 #11
0
 def test_returns_model_weights_for_model_callable(self):
   """Weights type is derived from a model callable's variables."""
   weights_type = model_utils.weights_type_from_model(TestModel)

   trainable_type = computation_types.StructWithPythonType([
       computation_types.TensorType(tf.float32, [3]),
       computation_types.TensorType(tf.float32, [1]),
   ], list)
   non_trainable_type = computation_types.StructWithPythonType(
       [computation_types.TensorType(tf.int32)], list)
   expected_type = computation_types.StructWithPythonType(
       [('trainable', trainable_type), ('non_trainable', non_trainable_type)],
       model_utils.ModelWeights)
   self.assertEqual(expected_type, weights_type)
 def test_transforms_unnamed_tuple_type_preserving_tuple_container(self):
   """Post-order transforms keep the `tuple` container of an unnamed struct."""
   orig_type = computation_types.StructWithPythonType(
       [tf.int32, tf.float64], tuple)
   expected_type = computation_types.StructWithPythonType(
       [tf.float32, tf.float32], tuple)

   # One transform rewrites every tensor; the other should match nothing.
   result_type, mutated = type_transformations.transform_type_postorder(
       orig_type, _convert_tensor_to_float)
   noop_type, not_mutated = type_transformations.transform_type_postorder(
       orig_type, _convert_abstract_type_to_tensor)

   self.assertEqual(result_type, expected_type)
   self.assertEqual(noop_type, orig_type)
   self.assertTrue(mutated)
   self.assertFalse(not_mutated)
예제 #13
0
 def test_capture_result_with_ragged_tensor(self):
   """Capturing a RaggedTensor yields a RaggedTensor-backed struct type."""
   with tf.Graph().as_default() as graph:
     ragged = tf.RaggedTensor.from_row_splits([0, 0, 0, 0], [0, 1, 4])
     # The binding is not inspected by this test.
     type_spec, _ = tensorflow_utils.capture_result_from_graph(ragged, graph)
     row_splits_type = computation_types.StructWithPythonType(
         [(None, computation_types.TensorType(tf.int64, [3]))], tuple)
     expected_type = computation_types.StructWithPythonType([
         ('flat_values', computation_types.TensorType(tf.int32, [4])),
         ('nested_row_splits', row_splits_type),
     ], tf.RaggedTensor)
     self.assert_types_identical(type_spec, expected_type)
예제 #14
0
 def test_succeeds_with_federated_namedtupletype(self):
     """A federated member type's container is used for the conversion."""
     anon_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0)])
     types = [tf.int32, tf.float32]
     cases = ((list, [1, 2.0]), (tuple, (1, 2.0)))
     for container, expected in cases:
         federated_type = computation_types.FederatedType(
             computation_types.StructWithPythonType(types, container),
             placement_literals.SERVER)
         self.assertSequenceEqual(
             type_conversions.type_to_py_container(anon_tuple, federated_type),
             expected)
예제 #15
0
 def test_not_anon_tuple_passthrough(self):
     """A value that is not an anonymous tuple is returned as-is."""
     value = (1, 2.0)
     struct_type = computation_types.StructWithPythonType(
         [tf.int32, tf.float32], container_type=list)
     result = type_conversions.type_to_py_container(value, struct_type)
     self.assertEqual(result, value)
 def test_ordered_dict(self):
     """An OrderedDict container is recorded and shown in the repr."""
     struct_type = computation_types.StructWithPythonType(
         [('a', tf.int32)], collections.OrderedDict)
     self.assertIs(struct_type.python_container, collections.OrderedDict)
     expected_repr = 'StructType([(\'a\', TensorType(tf.int32))]) as OrderedDict'
     self.assertEqual(repr(struct_type), expected_repr)
예제 #17
0
 def test_py_named_tuple(self):
   """A namedtuple container is recorded and named in the repr."""
   named_tuple_cls = collections.namedtuple('test_tuple', ['a'])
   struct_type = computation_types.StructWithPythonType(
       [('a', tf.int32)], named_tuple_cls)
   self.assertIs(struct_type.python_container, named_tuple_cls)
   expected_repr = 'StructType([(\'a\', TensorType(tf.int32))]) as test_tuple'
   self.assertEqual(repr(struct_type), expected_repr)
예제 #18
0
 def test_ragged_tensor(self):
     """A struct typed as a RaggedTensor is rebuilt into a RaggedTensor."""
     value = structure.Struct([
         ('flat_values', [0, 0, 0, 0]),
         ('nested_row_splits', [[0, 1, 4]]),
     ])
     row_splits_type = computation_types.StructWithPythonType(
         [(None, computation_types.TensorType(tf.int64, [3]))], tuple)
     value_type = computation_types.StructWithPythonType([
         ('flat_values', computation_types.TensorType(tf.int32, [4])),
         ('nested_row_splits', row_splits_type),
     ], tf.RaggedTensor)

     result = type_conversions.type_to_py_container(value, value_type)

     self.assertIsInstance(result, tf.RaggedTensor)
     self.assertAllEqual(result.flat_values, [0, 0, 0, 0])
     row_splits = result.nested_row_splits
     self.assertEqual(len(row_splits), 1)
     self.assertAllEqual(row_splits[0], [0, 1, 4])
예제 #19
0
 def test_struct_with_container_type(self):
   """A building-block Struct with a container reports a typed signature."""
   foo_ref = building_blocks.Reference('foo', tf.int32)
   bar_ref = building_blocks.Reference('bar', tf.bool)
   comp = building_blocks.Struct([foo_ref, ('y', bar_ref)], tuple)
   expected_type = computation_types.StructWithPythonType(
       [tf.int32, ('y', tf.bool)], tuple)
   self.assertEqual(comp.type_signature, expected_type)
    def test_py_attr_class(self):
        """An attrs class container is recorded and named in the repr."""

        @attr.s
        class TestFoo(object):
            a = attr.ib()

        struct_type = computation_types.StructWithPythonType(
            [('a', tf.int32)], TestFoo)
        self.assertIs(struct_type.python_container, TestFoo)
        expected_repr = 'StructType([(\'a\', TensorType(tf.int32))]) as TestFoo'
        self.assertEqual(repr(struct_type), expected_repr)
 def test_serialize_deserialize_named_tuple_types_py_container(self):
     """Serialization round-trips the structure but drops the Python container."""
     with_container = computation_types.StructWithPythonType(
         (tf.int32, tf.bool), tuple)
     serialized = type_serialization.serialize_type(with_container)
     without_container = type_serialization.deserialize_type(serialized)

     # The round-tripped type is a plain StructType, so it compares unequal
     # to the original but is still equivalent to it.
     self.assertNotEqual(with_container, without_container)
     self.assertIsInstance(without_container, computation_types.StructType)
     self.assertNotIsInstance(without_container,
                              computation_types.StructWithPythonType)
     with_container.check_equivalent_to(without_container)
예제 #22
0
 def test_client_placed_tuple(self):
   """Each client-placed struct member is converted to a `tuple`."""
   value = [
       structure.Struct([(None, first), (None, second)])
       for first, second in ((1, 2), (3, 4))
   ]
   member_type = computation_types.StructWithPythonType(
       [(None, tf.int32), (None, tf.int32)], tuple)
   type_spec = computation_types.FederatedType(member_type,
                                               placement_literals.CLIENTS)
   converted = type_conversions.type_to_py_container(value, type_spec)
   self.assertEqual([(1, 2), (3, 4)], converted)
예제 #23
0
 def transform_to_tff_known_type(
     type_spec: computation_types.Type) -> Tuple[computation_types.Type, bool]:
   """Transforms `StructType` to `StructWithPythonType`.

   Args:
     type_spec: Any TFF type.

   Returns:
     A `(type, mutated)` pair. A struct without a Python container gets
     `collections.OrderedDict` if every field is named (including the empty
     struct) or `tuple` if no field is named, and `mutated` is True. Any other
     type is returned unchanged with `mutated` False.

   Raises:
     TypeError: If the struct mixes named and unnamed fields.
   """
   if not type_spec.is_struct() or type_spec.is_struct_with_python():
     return type_spec, False
   names = [name for name, _ in structure.iter_elements(type_spec)]
   if all(name is not None for name in names):
     container = collections.OrderedDict
   elif all(name is None for name in names):
     container = tuple
   else:
     raise TypeError('Cannot represent TFF type in TF because it contains '
                     f'partially named structures. Type: {type_spec}')
   return computation_types.StructWithPythonType(
       elements=structure.iter_elements(type_spec),
       container_type=container), True
예제 #24
0
  def test_takes_namedtuple_polymorphic(self):
    """A polymorphic computation is specialized for each namedtuple arg type."""
    MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

    @tf.function
    def foo(t):
      self.assertIsInstance(t, MyType)
      return t.x + t.y

    wrapped_foo = computation_wrapper_instances.tensorflow_wrapper(foo)

    expectations = (
        (tf.int32, '(<x=int32,y=int32> -> int32)'),
        (tf.float32, '(<x=float32,y=float32> -> float32)'),
    )
    for dtype, expected_signature in expectations:
      argument_type = computation_types.StructWithPythonType(
          [('x', dtype), ('y', dtype)], MyType)
      concrete_fn = wrapped_foo.fn_for_argument_type(argument_type)
      self.assertEqual(concrete_fn.type_signature.compact_representation(),
                       expected_signature)
예제 #25
0
 def test_with_two_level_tuple(self):
     """Nested OrderedDict types map to nested dicts of TensorSpecs."""
     inner_type = computation_types.StructWithPythonType([
         ('c', computation_types.TensorType(tf.float32)),
         ('d', computation_types.TensorType(tf.int32, [20])),
     ], collections.OrderedDict)
     type_signature = computation_types.StructWithPythonType([
         ('a', tf.bool),
         ('b', inner_type),
         ('e', computation_types.StructType([])),
     ], collections.OrderedDict)

     tensor_specs = type_conversions.type_to_tf_tensor_specs(type_signature)

     # An empty struct is expected to map to an empty tuple.
     expected_specs = {
         'a': tf.TensorSpec([], tf.bool),
         'b': {
             'c': tf.TensorSpec([], tf.float32),
             'd': tf.TensorSpec([20], tf.int32),
         },
         'e': (),
     }
     self.assert_nested_struct_eq(tensor_specs, expected_specs)
예제 #26
0
 def test_stamp_parameter_in_graph_with_struct_with_python_type(self):
   """Stamping a named struct parameter yields one placeholder per field."""
   with tf.Graph().as_default() as my_graph:
     x = self._checked_stamp_parameter(
         'foo',
         computation_types.StructWithPythonType([('a', tf.int32),
                                                 ('b', tf.bool)],
                                                collections.OrderedDict))
   self.assertIsInstance(x, structure.Struct)
   # Compare the length explicitly: `assertTrue(len(x), 2)` would treat 2 as
   # the failure message and pass for any non-empty struct.
   self.assertEqual(len(x), 2)
   self._assert_is_placeholder(x.a, 'foo_a:0', tf.int32, [], my_graph)
   self._assert_is_placeholder(x.b, 'foo_b:0', tf.bool, [], my_graph)
예제 #27
0
 def test_capture_result_with_sparse_tensor(self):
   """Capturing a SparseTensor yields a SparseTensor-backed struct type."""
   with tf.Graph().as_default() as graph:
     sparse = tf.SparseTensor(indices=[[1]], values=[2], dense_shape=[5])
     # The binding is not inspected by this test.
     type_spec, _ = tensorflow_utils.capture_result_from_graph(sparse, graph)
     expected_type = computation_types.StructWithPythonType([
         ('indices', computation_types.TensorType(tf.int64, [1, 1])),
         ('values', computation_types.TensorType(tf.int32, [1])),
         ('dense_shape', computation_types.TensorType(tf.int64, [1])),
     ], tf.SparseTensor)
     self.assert_types_identical(type_spec, expected_type)
예제 #28
0
 def test_with_sparse_tensor(self):
     """A SparseTensor is inferred as a struct with a SparseTensor container."""
     # Dense equivalent: [0, 2, 0, 0, 0].
     sparse_tensor = tf.SparseTensor(
         indices=[[1]], values=[2], dense_shape=[5])
     inferred_type = type_conversions.infer_type(sparse_tensor)
     expected_type = computation_types.StructWithPythonType([
         ('indices', computation_types.TensorType(tf.int64, [1, 1])),
         ('values', computation_types.TensorType(tf.int32, [1])),
         ('dense_shape', computation_types.TensorType(tf.int64, [1])),
     ], tf.SparseTensor)
     self.assert_types_identical(inferred_type, expected_type)
예제 #29
0
 def test_with_tensor_triple(self):
     """A flat OrderedDict of tensors maps to a dict of TensorSpecs."""
     field_types = [
         ('a', computation_types.TensorType(tf.int32, [5])),
         ('b', computation_types.TensorType(tf.bool)),
         ('c', computation_types.TensorType(tf.float32, [3])),
     ]
     type_signature = computation_types.StructWithPythonType(
         field_types, collections.OrderedDict)
     tensor_specs = type_conversions.type_to_tf_tensor_specs(type_signature)
     expected_specs = {
         'a': tf.TensorSpec([5], tf.int32),
         'b': tf.TensorSpec([], tf.bool),
         'c': tf.TensorSpec([3], tf.float32),
     }
     test.assert_nested_struct_eq(tensor_specs, expected_specs)
예제 #30
0
 def test_without_names(self):
     """An unnamed tuple type round-trips to a dataset's element_spec."""
     expected_structure = (
         tf.TensorSpec(shape=(), dtype=tf.bool),
         tf.TensorSpec(shape=(), dtype=tf.int32),
     )
     type_spec = computation_types.StructWithPythonType(
         expected_structure, tuple)
     tf_structure = type_conversions.type_to_tf_structure(type_spec)
     with tf.Graph().as_default():
         variant = tf.compat.v1.placeholder(tf.variant, shape=[])
         ds = tf.data.experimental.from_variant(
             variant, structure=tf_structure)
         self.assertEqual(expected_structure, ds.element_spec)