Пример #1
0
 def test_multiple_named_and_unnamed(self):
   """Checks a `Struct` mixing one unnamed and two named elements.

   Covers positional and attribute access, iteration, `dir` (which lists only
   the names, sorted), `name_list` (names in sequence order), equality,
   element conversion, the string forms, and the failures of the dict
   conversions on a partially-named struct.
   """
   v = [(None, 10), ('foo', 20), ('bar', 30)]
   x = structure.Struct(v)
   self.assertLen(x, 3)
   self.assertEqual(x[0], 10)
   self.assertEqual(x[1], 20)
   self.assertEqual(x[2], 30)
   with self.assertRaises(IndexError):
     _ = x[3]
   self.assertEqual(list(iter(x)), [10, 20, 30])
   self.assertEqual(dir(x), ['bar', 'foo'])
   self.assertEqual(structure.name_list(x), ['foo', 'bar'])
   self.assertEqual(x.foo, 20)
   self.assertEqual(x.bar, 30)
   with self.assertRaises(AttributeError):
     _ = x.baz
   self.assertEqual(x, structure.Struct([(None, 10), ('foo', 20),
                                         ('bar', 30)]))
   # Same values but names attached to different positions: not equal.
   self.assertNotEqual(
       x, structure.Struct([('foo', 10), ('bar', 20), (None, 30)]))
   self.assertEqual(structure.to_elements(x), v)
   self.assertEqual(
       repr(x), "Struct([(None, 10), ('foo', 20), ('bar', 30)])")
   self.assertEqual(str(x), '<10,foo=20,bar=30>')
   with self.assertRaisesRegex(ValueError, 'unnamed'):
     structure.to_odict(x)
   with self.assertRaisesRegex(ValueError, 'named and unnamed'):
     structure.to_odict_or_tuple(x)
Пример #2
0
def check_finalizers_matches_unfinalized_metrics(
        metric_finalizers: model_lib.MetricFinalizersType,
        local_unfinalized_metrics_type: computation_types.StructWithPythonType
):
    """Verifies the compatibility of variables and finalizers.

    The check passes when the set of keys of `metric_finalizers` is exactly the
    set of names returned by `structure.name_list` for
    `local_unfinalized_metrics_type`. Nothing is returned on success.

    Args:
      metric_finalizers: The finalizers to validate.
      local_unfinalized_metrics_type: The unfinalized metrics type to validate.

    Raises:
      ValueError: If `metric_finalizers` cannot finalize a variable structure
        with type `local_unfinalized_metrics_type`. The message lists the names
        present on only one side as well as both full name sets.
    """
    metric_names_in_metric_finalizers = set(metric_finalizers.keys())
    metric_names_in_local_unfinalized_metrics = set(
        structure.name_list(local_unfinalized_metrics_type))
    if (metric_names_in_metric_finalizers ==
            metric_names_in_local_unfinalized_metrics):
        return
    only_in_finalizers = (metric_names_in_metric_finalizers -
                          metric_names_in_local_unfinalized_metrics)
    only_in_unfinalized_metrics = (metric_names_in_local_unfinalized_metrics -
                                   metric_names_in_metric_finalizers)
    raise ValueError(
        'The metric names in `metric_finalizers` do not match those in the '
        '`local_unfinalized_metrics`. Metric names in the `metric_finalizers` '
        f'but not the `local_unfinalized_metrics`: {only_in_finalizers}. '
        'Metric names in the `local_unfinalized_metrics` but not the '
        f'`metric_finalizers`: {only_in_unfinalized_metrics}.\n'
        'Metric names in the `metric_finalizers`: '
        f'{metric_names_in_metric_finalizers}. Metric names in the '
        '`local_unfinalized_metrics`: '
        f'{metric_names_in_local_unfinalized_metrics}.')
Пример #3
0
def _parameter_type(
    parameters, parameter_types: Tuple[computation_types.Type, ...]
) -> Optional[computation_types.Type]:
    """Bundle any user-provided parameter types into a single argument type.

    Args:
      parameters: The function's parameters; only each entry's `.name` is read
        (presumably `inspect.Parameter` objects - confirm against callers).
      parameter_types: The user-provided types, possibly empty.

    Returns:
      `None` when there are neither parameters nor types, or when the single
      provided type is `None` and the function has at most one parameter.
      Otherwise, a single type describing the whole argument:
      - one type and one parameter: that type unchanged;
      - one struct type and several parameters: the struct itself if all of
        its fields are named after the parameters, or an `OrderedDict`-backed
        struct applying the parameter names to its unnamed fields;
      - several types and one parameter: the types bundled by `to_type`;
      - `n` types and `n` parameters: an `OrderedDict`-backed struct keyed by
        the parameter names.

    Raises:
      TypeError: If the provided type(s) cannot be matched to the parameters.
    """
    parameter_names = [parameter.name for parameter in parameters]
    if not parameter_types and not parameters:
        return None
    if len(parameter_types) == 1:
        parameter_type = parameter_types[0]
        if parameter_type is None and not parameters:
            return None
        if len(parameters) == 1:
            return parameter_type
        # There is a single parameter type but multiple parameters, so the type
        # must be a struct with one element per parameter. A `None` type can't
        # be unpacked; reject it with a TypeError instead of letting
        # `None.is_struct()` fail with an AttributeError.
        if (parameter_type is None or not parameter_type.is_struct() or
                len(parameter_type) != len(parameters)):
            raise TypeError(
                f'Function with {len(parameters)} parameters must have a parameter '
                f'type with the same number of parameters. Found parameter type '
                f'{parameter_type}.')
        name_list_from_types = structure.name_list(parameter_type)
        if name_list_from_types:
            if len(name_list_from_types) != len(parameter_type):
                raise TypeError(
                    'Types with both named and unnamed fields cannot be unpacked into '
                    f'argument lists. Found parameter type {parameter_type}.')
            if set(name_list_from_types) != set(parameter_names):
                raise TypeError(
                    'Function argument names must match field names of parameter type. '
                    f'Found argument names {parameter_names}, which do not match '
                    f'{name_list_from_types}, the top-level fields of the parameter '
                    f'type {parameter_type}.')
            # The provided parameter type has all named fields which exactly match
            # the names of the function's parameters.
            return parameter_type
        # The provided parameter type has no named fields. Apply the names from
        # the function parameters (kept in a new local rather than rebinding the
        # `parameter_types` argument).
        element_types = [v for (_, v) in structure.to_elements(parameter_type)]
        return computation_types.StructWithPythonType(
            list(zip(parameter_names, element_types)), collections.OrderedDict)
    if len(parameters) == 1:
        # If there are multiple provided argument types but the function being
        # decorated only accepts a single argument, tuple the arguments together.
        return computation_types.to_type(parameter_types)
    if len(parameters) != len(parameter_types):
        raise TypeError(
            f'Function with {len(parameters)} parameters is '
            f'incompatible with provided argument types {parameter_types}.')
    # The function has `n` parameters and `n` parameter types.
    # Zip them up into a structure using the names from the function as keys.
    return computation_types.StructWithPythonType(
        list(zip(parameter_names, parameter_types)), collections.OrderedDict)
Пример #4
0
    def __getattr__(self, name):
        """Returns a `Value` selecting the field `name` of this (federated) struct.

        `name` must be a `str`; `_check_struct_or_federated_struct` is called
        first (presumably it rejects values that are not structs - confirm).
        For a federated named tuple, the selection is built with
        `create_federated_getattr_call`. Otherwise the field is taken directly
        from the underlying struct building block, or wrapped in a `Selection`.

        Raises:
          AttributeError: If `name` is not a field of the (member) type; the
            message lists the valid attribute names.
        """
        py_typecheck.check_type(name, str)
        _check_struct_or_federated_struct(self, name)

        if _is_federated_named_tuple(self):
            member_type = self.type_signature.member
            # Membership uses `name_list`; the error message lists `dir` names.
            if name in structure.name_list(member_type):
                return Value(
                    building_block_factory.create_federated_getattr_call(
                        self._comp, name))
            valid_member_attributes = ', '.join(dir(member_type))
            raise AttributeError(
                'There is no such attribute \'{}\' in this federated tuple. Valid '
                'attributes: ({})'.format(name, valid_member_attributes))

        valid_names = dir(self.type_signature)
        if name not in valid_names:
            raise AttributeError(
                'There is no such attribute \'{}\' in this tuple. Valid attributes: ({})'
                .format(name, ', '.join(valid_names)))

        comp = self._comp
        if comp.is_struct():
            # A literal struct can hand back the element directly.
            selected = getattr(comp, name)
        else:
            selected = building_blocks.Selection(comp, name=name)
        return Value(selected)
Пример #5
0
    def _check_ordereddict_container_for_struct(type_to_check):
        """Rejects struct types whose traversal order would be ambiguous.

        Non-struct types, structs without names, and structs whose names are
        already in sorted order pass unchanged. A named struct whose sequence
        order differs from sorted order must carry a Python container, and that
        container must be `collections.OrderedDict`.

        Returns:
          The pair `(type_to_check, False)` in every non-raising case; the type
          is never modified (presumably the shape a type-transformation callback
          expects - confirm against the caller).

        Raises:
          ValueError: If the struct has out-of-order names and either no Python
            container or a container other than `collections.OrderedDict`.
        """
        if not type_to_check.is_struct():
            return type_to_check, False

        # `dir` would sort the names, hiding the sequence order, so use
        # `name_list`, which keeps sequence order and only includes names that
        # are actually present.
        ordered_names = structure.name_list(type_to_check)
        in_sorted_order = _names_are_in_sorted_order(ordered_names)

        # With no names, sequence order is the only possible traversal; with
        # names already in alphabetical order, TFF's deserialization traverses
        # the structure the same way. Either way there is no ambiguity.
        if not ordered_names or in_sorted_order:
            return type_to_check, False

        # Past this point the struct is named and its names are NOT sorted.
        if not type_to_check.is_struct_with_python():
            raise ValueError(
                'Attempting to serialize a named struct type with '
                'ambiguous traversal order (sequence order distinct '
                'from alphabetical order) without a Python container; '
                'this is an unsafe operation, as TFF cannot determine '
                'the intended traversal order after deserializing the '
                'proto due to inconsistent behavior of tf.nest.')

        container = computation_types.StructWithPythonType.get_container_type(
            type_to_check)
        if container is not collections.OrderedDict:
            raise ValueError(
                'Attempted to serialize a dataset yielding named '
                'elements in non-sorted sequence order with '
                f'non-OrderedDict container (type {container}). '
                'This is an ambiguous operation; `tf.nest` behaves in '
                'a manner which depends on the Python type of this '
                'container, so coercing the dataset reconstructed '
                'from the resulting Value proto depends on assuming a '
                'single Python type here. Please prefer to use '
                '`collections.OrderedDict` containers for the elements '
                'your dataset yields.')
        return type_to_check, False
Пример #6
0
 def test_create_struct(self):
     """Checks that `create_struct` bundles two named int32 values.

     The struct value should have type `<v10=int32,v20=int32>` and hold each
     element as an `EagerValue` that computes back to its original integer.
     """
     executor = self._sequence_executor
     int32_type = computation_types.TensorType(tf.int32)
     named_values = [
         (f'v{x}', _run_sync(executor.create_value(x, int32_type)))
         for x in [10, 20]
     ]
     val = _run_sync(executor.create_struct(structure.Struct(named_values)))
     self.assertIsInstance(val, sequence_executor.SequenceExecutorValue)
     self.assertEqual(str(val.type_signature), '<v10=int32,v20=int32>')

     representation = val.internal_representation
     self.assertIsInstance(representation, structure.Struct)
     self.assertListEqual(structure.name_list(representation), ['v10', 'v20'])
     for field_name, expected in (('v10', 10), ('v20', 20)):
         element = getattr(representation, field_name)
         self.assertIsInstance(element, eager_tf_executor.EagerValue)
         self.assertEqual(_run_sync(element.compute()), expected)