# コード例 #1 (Code example #1)
def type_to_tf_structure(type_spec: computation_types.Type):
    """Returns nested `tf.data.experimental.Structure` for a given TFF type.

    Args:
      type_spec: A `computation_types.Type`, the type specification must be
        composed of only named tuples and tensors. Tuples must be non-empty,
        and within each tuple the elements must be either all named or all
        unnamed.

    Returns:
      An instance of `tf.data.experimental.Structure`, possibly nested, that
      corresponds to `type_spec`. Tensors become `tf.TensorSpec`s. Tuples that
      carry a Python container type are rebuilt with that container; otherwise
      named tuples become `collections.OrderedDict`s and unnamed tuples become
      Python `tuple`s.

    Raises:
      ValueError: if the `type_spec` is composed of something other than named
        tuples and tensors, if it contains an empty tuple, or if the elements
        of a tuple are inconsistently named.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        return tf.TensorSpec(type_spec.shape, type_spec.dtype)
    elif type_spec.is_tuple():
        elements = anonymous_tuple.to_elements(type_spec)
        if not elements:
            raise ValueError('Empty tuples are unsupported.')
        element_outputs = [(k, type_to_tf_structure(v)) for k, v in elements]
        # The first element decides whether the tuple is named; all other
        # elements must agree, so mixed naming is rejected.
        named = element_outputs[0][0] is not None
        if not all((e[0] is not None) == named for e in element_outputs):
            raise ValueError('Tuple elements inconsistently named.')
        if not type_spec.is_tuple_with_py_container():
            if named:
                output = collections.OrderedDict(element_outputs)
            else:
                output = tuple(v for _, v in element_outputs)
        else:
            container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
                type_spec)
            if (py_typecheck.is_named_tuple(container_type)
                    or py_typecheck.is_attrs(container_type)):
                # namedtuple/attrs containers are built from keyword arguments.
                output = container_type(**dict(element_outputs))
            elif named:
                output = container_type(element_outputs)
            else:
                # Every name is None here (checked above), so hand the
                # container only the element structures.
                output = container_type(v for _, v in element_outputs)
        return output
    else:
        raise ValueError('Unsupported type {}.'.format(type_spec))
# コード例 #2 (Code example #2)
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]):
    """Walks type tree of `type_signature` postorder, calling `transform_fn`.

    Children are transformed before their parent. A parent is rebuilt from its
    transformed children only when at least one child reports a mutation, and
    `transform_fn` is then applied to the (possibly rebuilt) parent.

    Args:
      type_signature: Instance of `computation_types.Type` to transform
        recursively.
      transform_fn: Transformation function to apply to each node in the type
        tree of `type_signature`. Must be callable; it receives a single
        `computation_types.Type` and returns a `(type, mutated)` tuple, where
        `mutated` is a `bool` saying whether it changed the node.

    Returns:
      A tuple `(transformed_type, mutated)`. `transformed_type` is a possibly
      transformed version of `type_signature`, with each node in its tree the
      result of applying `transform_fn` to the corresponding node in
      `type_signature`; `mutated` is `True` if any node was changed.

    Raises:
      TypeError: If the types don't match the specification above, or if
        `type_signature` is of a kind this function does not know how to walk.
    """
    py_typecheck.check_type(type_signature, computation_types.Type)
    if not callable(transform_fn):
        raise TypeError('Expected `transform_fn` to be callable, found {}.'.format(
            type(transform_fn)))
    if type_signature.is_federated():
        transformed_member, member_mutated = transform_type_postorder(
            type_signature.member, transform_fn)
        if member_mutated:
            type_signature = computation_types.FederatedType(
                transformed_member, type_signature.placement,
                type_signature.all_equal)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or member_mutated
    elif type_signature.is_sequence():
        transformed_element, element_mutated = transform_type_postorder(
            type_signature.element, transform_fn)
        if element_mutated:
            type_signature = computation_types.SequenceType(transformed_element)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or element_mutated
    elif type_signature.is_function():
        if type_signature.parameter is not None:
            transformed_parameter, parameter_mutated = transform_type_postorder(
                type_signature.parameter, transform_fn)
        else:
            # Parameterless function: nothing to transform on the input side.
            transformed_parameter, parameter_mutated = (None, False)
        transformed_result, result_mutated = transform_type_postorder(
            type_signature.result, transform_fn)
        if parameter_mutated or result_mutated:
            type_signature = computation_types.FunctionType(
                transformed_parameter, transformed_result)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, (type_signature_mutated or parameter_mutated
                                or result_mutated)
    elif type_signature.is_tuple():
        elements = []
        elements_mutated = False
        for name, element_type in anonymous_tuple.iter_elements(type_signature):
            transformed_element, element_mutated = transform_type_postorder(
                element_type, transform_fn)
            elements_mutated = elements_mutated or element_mutated
            elements.append((name, transformed_element))
        if elements_mutated:
            if type_signature.is_tuple_with_py_container():
                # Preserve the Python container so downstream conversions
                # can rebuild the same container type.
                container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
                    type_signature)
                type_signature = computation_types.NamedTupleTypeWithPyContainerType(
                    elements, container_type)
            else:
                type_signature = computation_types.NamedTupleType(elements)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or elements_mutated
    elif type_signature.is_abstract() or type_signature.is_placement(
    ) or type_signature.is_tensor():
        # Leaf types have no children; only the node itself is transformed.
        return transform_fn(type_signature)
    else:
        # Previously this fell through and returned None, which broke callers
        # that unpack the `(type, mutated)` result.
        raise TypeError('Unrecognized type kind {}.'.format(type_signature))
# コード例 #3 (Code example #3)
def type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type):
    """Returns nested structures of tensor dtypes and shapes for a given TFF type.

    The returned dtypes and shapes match those used by `tf.data.Dataset`s to
    indicate the type and shape of their elements. They can be used, e.g., as
    arguments in constructing an iterator over a string handle.

    Args:
      type_spec: A `computation_types.Type`, the type specification must be
        composed of only named tuples and tensors. Within each tuple the
        elements must be either all named or all unnamed.

    Returns:
      A pair of parallel nested structures with the dtypes and shapes of tensors
      defined in `type_spec`. The layout of the two structures returned is the
      same as the layout of the nested type defined by `type_spec`. Tuples that
      carry a Python container type are rebuilt with that container; otherwise
      named tuples are represented as dictionaries (`collections.OrderedDict`)
      and unnamed or empty tuples as Python `tuple`s.

    Raises:
      ValueError: if the `type_spec` is composed of something other than named
        tuples and tensors, or if the elements of a tuple are inconsistently
        named.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        return (type_spec.dtype, type_spec.shape)
    elif type_spec.is_tuple():
        elements = anonymous_tuple.to_elements(type_spec)
        # The first element decides whether the tuple is named; an empty tuple
        # is treated as unnamed.
        named = bool(elements) and elements[0][0] is not None
        if named:
            output_dtypes = collections.OrderedDict()
            output_shapes = collections.OrderedDict()
        else:
            output_dtypes = []
            output_shapes = []
        for element_name, element_spec in elements:
            # Each element is checked (and then recursed into) in order, so a
            # naming mismatch is reported at the first offending element.
            if (element_name is not None) != named:
                raise ValueError(
                    'When a sequence appears as a part of a parameter to a section '
                    'of TensorFlow code, in the type signature of elements of that '
                    'sequence all named tuples must have their elements explicitly '
                    'named, and this does not appear to be the case in {}.'.format(
                        type_spec))
            element_dtypes, element_shapes = type_to_tf_dtypes_and_shapes(
                element_spec)
            if named:
                output_dtypes[element_name] = element_dtypes
                output_shapes[element_name] = element_shapes
            else:
                output_dtypes.append(element_dtypes)
                output_shapes.append(element_shapes)
        if type_spec.is_tuple_with_py_container():
            container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
                type_spec)

            def build_py_container(items):
                # namedtuple/attrs containers are built from keyword arguments;
                # other containers take the collected items directly.
                if (py_typecheck.is_named_tuple(container_type)
                        or py_typecheck.is_attrs(container_type)):
                    return container_type(**dict(items))
                else:
                    return container_type(items)

            output_dtypes = build_py_container(output_dtypes)
            output_shapes = build_py_container(output_shapes)
        elif not named:
            # Named results stay OrderedDicts as documented: calling tuple() on
            # a dict would keep only its keys and drop the dtypes/shapes.
            output_dtypes = tuple(output_dtypes)
            output_shapes = tuple(output_shapes)
        return (output_dtypes, output_shapes)
    else:
        raise ValueError('Unsupported type {}.'.format(type_spec))