def test_multiple_named_and_unnamed(self):
  """Checks a `Struct` that mixes one unnamed field with two named fields."""
  v = [(None, 10), ('foo', 20), ('bar', 30)]
  x = structure.Struct(v)
  expected_values = [10, 20, 30]

  # Positional access, bounds and iteration.
  self.assertLen(x, 3)
  for index, expected in enumerate(expected_values):
    self.assertEqual(x[index], expected)
  with self.assertRaises(IndexError):
    _ = x[3]
  self.assertEqual(list(x), expected_values)

  # Name-based access: `dir` is sorted, `name_list` keeps sequence order and
  # both skip the unnamed field.
  self.assertEqual(dir(x), ['bar', 'foo'])
  self.assertEqual(structure.name_list(x), ['foo', 'bar'])
  self.assertEqual(x.foo, 20)
  self.assertEqual(x.bar, 30)
  with self.assertRaises(AttributeError):
    _ = x.baz

  # Equality depends on where the names sit, not just on the values.
  self.assertEqual(x, structure.Struct(list(v)))
  shifted_names = structure.Struct([('foo', 10), ('bar', 20), (None, 30)])
  self.assertNotEqual(x, shifted_names)

  # Conversions and string forms.
  self.assertEqual(structure.to_elements(x), v)
  self.assertEqual(repr(x), "Struct([(None, 10), ('foo', 20), ('bar', 30)])")
  self.assertEqual(str(x), '<10,foo=20,bar=30>')
  with self.assertRaisesRegex(ValueError, 'unnamed'):
    structure.to_odict(x)
  with self.assertRaisesRegex(ValueError, 'named and unnamed'):
    structure.to_odict_or_tuple(x)
def check_finalizers_matches_unfinalized_metrics(
    metric_finalizers: model_lib.MetricFinalizersType,
    local_unfinalized_metrics_type: computation_types.StructWithPythonType):
  """Verifies that the finalizers and the unfinalized metrics are compatible.

  Compatibility is checked by name only: the keys of `metric_finalizers` must
  be exactly the named top-level fields of `local_unfinalized_metrics_type`
  (as reported by `structure.name_list`, which ignores unnamed fields).

  Args:
    metric_finalizers: The finalizers to validate. Only its keys are read.
    local_unfinalized_metrics_type: The unfinalized metrics type to validate.
      Only its top-level field names are read.

  Raises:
    ValueError: If `metric_finalizers` cannot finalize a variable structure
      with type `local_unfinalized_metrics_type`, i.e. the two sets of metric
      names differ.
  """
  metric_names_in_metric_finalizers = set(metric_finalizers.keys())
  metric_names_in_local_unfinalized_metrics = set(
      structure.name_list(local_unfinalized_metrics_type))
  if (metric_names_in_metric_finalizers ==
      metric_names_in_local_unfinalized_metrics):
    return
  # Report both one-sided differences so the mismatch is easy to locate.
  difference_1 = (
      metric_names_in_metric_finalizers -
      metric_names_in_local_unfinalized_metrics)
  difference_2 = (
      metric_names_in_local_unfinalized_metrics -
      metric_names_in_metric_finalizers)
  raise ValueError(
      'The metric names in `metric_finalizers` do not match those in the '
      '`local_unfinalized_metrics`. Metric names in the `metric_finalizers` '
      f'but not the `local_unfinalized_metrics`: {difference_1}. '
      'Metric names in the `local_unfinalized_metrics` but not the '
      f'`metric_finalizers`: {difference_2}.\n'
      'Metric names in the `metric_finalizers`: '
      f'{metric_names_in_metric_finalizers}. Metric names in the '
      '`local_unfinalized_metrics`: '
      f'{metric_names_in_local_unfinalized_metrics}.')
def _parameter_type(
    parameters, parameter_types: Tuple[computation_types.Type, ...]
) -> Optional[computation_types.Type]:
  """Bundle any user-provided parameter types into a single argument type.

  Args:
    parameters: The parameters of the function being wrapped. Only each
      parameter's `name` attribute is read (presumably `inspect.Parameter`
      objects; TODO confirm against callers).
    parameter_types: The user-provided parameter types, possibly empty. A lone
      entry may be `None`.

  Returns:
    `None` when there is nothing to bundle (no parameters and no types, or a
    single `None` type and no parameters). Otherwise a single type that
    describes the whole argument list.

  Raises:
    TypeError: If the provided types cannot be matched up with the function's
      parameters.
  """
  parameter_names = [parameter.name for parameter in parameters]
  if not parameter_types and not parameters:
    return None

  if len(parameter_types) == 1:
    parameter_type = parameter_types[0]
    if parameter_type is None and not parameters:
      return None
    if len(parameters) == 1:
      return parameter_type
    # A single parameter type for zero or several parameters must be a struct
    # with exactly one field per parameter. A `None` type cannot be unpacked.
    # It is rejected explicitly so callers get a `TypeError` rather than an
    # `AttributeError` from `None.is_struct()`.
    if (parameter_type is None or not parameter_type.is_struct() or
        len(parameter_type) != len(parameters)):
      raise TypeError(
          f'Function with {len(parameters)} parameters must have a parameter '
          f'type with the same number of parameters. Found parameter type '
          f'{parameter_type}.')
    name_list_from_types = structure.name_list(parameter_type)
    if not name_list_from_types:
      # The provided parameter type has no named fields. Apply the names from
      # the function parameters, in order.
      field_types = [v for (_, v) in structure.to_elements(parameter_type)]
      return computation_types.StructWithPythonType(
          list(zip(parameter_names, field_types)), collections.OrderedDict)
    # `structure.name_list` skips unnamed fields, so a name list shorter than
    # the struct means named and unnamed fields are mixed.
    if len(name_list_from_types) != len(parameter_type):
      raise TypeError(
          'Types with both named and unnamed fields cannot be unpacked into '
          f'argument lists. Found parameter type {parameter_type}.')
    if set(name_list_from_types) != set(parameter_names):
      raise TypeError(
          'Function argument names must match field names of parameter type. '
          f'Found argument names {parameter_names}, which do not match '
          f'{name_list_from_types}, the top-level fields of the parameter '
          f'type {parameter_type}.')
    # The provided parameter type has all named fields which exactly match
    # the names of the function's parameters.
    return parameter_type

  if len(parameters) == 1:
    # Several provided argument types (or none at all) for a function that
    # accepts a single argument: tuple the argument types together.
    return computation_types.to_type(parameter_types)

  if len(parameters) != len(parameter_types):
    raise TypeError(
        f'Function with {len(parameters)} parameters is '
        f'incompatible with provided argument types {parameter_types}.')
  # The function has `n` parameters and `n` parameter types.
  # Zip them up into a structure using the names from the function as keys.
  return computation_types.StructWithPythonType(
      list(zip(parameter_names, parameter_types)), collections.OrderedDict)
def __getattr__(self, name):
  """Returns a `Value` that selects the field `name` of this struct value.

  For a federated struct, the selection is expressed through
  `building_block_factory.create_federated_getattr_call`. Otherwise the value
  is taken directly from a struct building block, or wrapped in a
  `building_blocks.Selection`.

  Raises:
    AttributeError: If `name` is not a field of the (member) struct type.
  """
  py_typecheck.check_type(name, str)
  _check_struct_or_federated_struct(self, name)

  if not _is_federated_named_tuple(self):
    if name not in dir(self.type_signature):
      valid_attributes = ', '.join(dir(self.type_signature))
      raise AttributeError(
          f"There is no such attribute '{name}' in this tuple. "
          f"Valid attributes: ({valid_attributes})")
    if self._comp.is_struct():
      return Value(getattr(self._comp, name))
    return Value(building_blocks.Selection(self._comp, name=name))

  member_type = self.type_signature.member
  if name not in structure.name_list(member_type):
    valid_attributes = ', '.join(dir(member_type))
    raise AttributeError(
        f"There is no such attribute '{name}' in this federated tuple. "
        f"Valid attributes: ({valid_attributes})")
  getattr_call = building_block_factory.create_federated_getattr_call(
      self._comp, name)
  return Value(getattr_call)
def _check_ordereddict_container_for_struct(type_to_check):
  """Rejects named struct types whose traversal order would be ambiguous.

  A struct type is accepted when it has no names, or when its names in
  sequence order are already in sorted order (per
  `_names_are_in_sorted_order`). Otherwise it must be a struct type with a
  Python container, and that container must be `collections.OrderedDict`.
  Non-struct types are accepted unchanged.

  Args:
    type_to_check: The type to validate.

  Returns:
    The tuple `(type_to_check, False)` on every non-raising path.

  Raises:
    ValueError: If the names are not sorted and the type has no Python
      container, or its container is not `collections.OrderedDict`.
  """
  if not type_to_check.is_struct():
    return type_to_check, False

  # `dir` would return the names sorted; `structure.name_list` keeps the
  # sequence order and only lists names that are actually present.
  ordered_names = structure.name_list(type_to_check)
  in_sorted_order = _names_are_in_sorted_order(ordered_names)

  # Without names, sequence order is the only possible traversal. With names
  # already sorted, TFF's deserialization traverses the struct the same way.
  # Either way there is no ambiguity.
  if not ordered_names or in_sorted_order:
    return type_to_check, False

  # From here on the struct has names in a non-sorted sequence order.
  if not type_to_check.is_struct_with_python():
    raise ValueError(
        'Attempting to serialize a named struct type with '
        'ambiguous traversal order (sequence order distinct '
        'from alphabetical order) without a Python container; '
        'this is an unsafe operation, as TFF cannot determine '
        'the intended traversal order after deserializing the '
        'proto due to inconsistent behavior of tf.nest.')

  container_type = computation_types.StructWithPythonType.get_container_type(
      type_to_check)
  if container_type is collections.OrderedDict:
    return type_to_check, False
  raise ValueError(
      'Attempted to serialize a dataset yielding named '
      'elements in non-sorted sequence order with '
      f'non-OrderedDict container (type {container_type}). '
      'This is an ambiguous operation; `tf.nest` behaves in '
      'a manner which depends on the Python type of this '
      'container, so coercing the dataset reconstructed '
      'from the resulting Value proto depends on assuming a '
      'single Python type here. Please prefer to use '
      '`collections.OrderedDict` containers for the elements '
      'your dataset yields.')
def test_create_struct(self):
  """Builds a two-field struct in the sequence executor and inspects it."""
  executor = self._sequence_executor
  named_values = [
      (f'v{x}',
       _run_sync(executor.create_value(
           x, computation_types.TensorType(tf.int32))))
      for x in [10, 20]
  ]
  elements = structure.Struct(named_values)

  val = _run_sync(executor.create_struct(elements))
  self.assertIsInstance(val, sequence_executor.SequenceExecutorValue)
  self.assertEqual(str(val.type_signature), '<v10=int32,v20=int32>')

  representation = val.internal_representation
  self.assertIsInstance(representation, structure.Struct)
  self.assertListEqual(structure.name_list(representation), ['v10', 'v20'])

  # First check every field's type, then compute every field's value.
  for field_name in ('v10', 'v20'):
    self.assertIsInstance(
        getattr(representation, field_name), eager_tf_executor.EagerValue)
  for field_name, expected in (('v10', 10), ('v20', 20)):
    self.assertEqual(
        _run_sync(getattr(representation, field_name).compute()), expected)