コード例 #1
0
def type_to_tf_structure(type_spec: computation_types.Type):
    """Converts a TFF type into a (possibly nested) structure of `tf.TensorSpec`s.

    Tensor types map to a single `tf.TensorSpec`. Tuple types are converted
    element by element and then packaged as follows:

    * tuple without a Python container, all elements named: an
      `collections.OrderedDict` keyed by element name;
    * tuple without a Python container, all elements unnamed: a plain `tuple`;
    * tuple with a Python container: an instance of that container type.

    Args:
      type_spec: A `computation_types.Type` composed only of tensors and
        non-empty tuples. Within any one tuple the elements must be either all
        named or all unnamed.

    Returns:
      A `tf.TensorSpec`, or a nested structure of them mirroring `type_spec`.

    Raises:
      ValueError: If `type_spec` contains an empty tuple, a tuple whose
        elements are inconsistently named, or anything other than tensors and
        tuples.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        return tf.TensorSpec(type_spec.shape, type_spec.dtype)

    if not type_spec.is_tuple():
        raise ValueError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_spec))))

    members = anonymous_tuple.to_elements(type_spec)
    if not members:
        raise ValueError('Empty tuples are unsupported.')

    # Convert children first so that errors in nested types surface before the
    # naming check of this level.
    converted = [(name, type_to_tf_structure(spec)) for name, spec in members]

    has_name = [name is not None for name, _ in converted]
    named = has_name[0]
    if any(flag != named for flag in has_name):
        raise ValueError('Tuple elements inconsistently named.')

    if not type_spec.is_tuple_with_py_container():
        if named:
            return collections.OrderedDict(converted)
        return tuple(value for _, value in converted)

    container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
        type_spec)
    # Named tuples and attrs classes are constructed from keyword arguments.
    if (py_typecheck.is_named_tuple(container_type)
            or py_typecheck.is_attrs(container_type)):
        return container_type(**dict(converted))
    if named:
        return container_type(converted)
    # Every element is unnamed here (checked above), so only values are passed.
    return container_type(value for _, value in converted)
コード例 #2
0
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]):
    """Walks the type tree of `type_signature` postorder, applying `transform_fn`.

    Children of a node are transformed first; if any child changed, the node is
    rebuilt from the transformed children, and only then is `transform_fn`
    applied to the (possibly rebuilt) node itself.

    Args:
      type_signature: Instance of `computation_types.Type` to transform
        recursively.
      transform_fn: Callable applied to each node in the type tree. It takes a
        `computation_types.Type` and returns a pair
        `(transformed_type, mutated)`, where `mutated` indicates whether the
        node was changed.

    Returns:
      A pair `(transformed_type, mutated)`: the possibly transformed version of
      `type_signature`, and a flag that is truthy if `transform_fn` reported a
      mutation for this node or for any node beneath it.

    Raises:
      TypeError: If `type_signature` is of a kind this function does not know
        how to traverse. The argument checks done through `py_typecheck` may
        raise as well.
    """
    py_typecheck.check_type(type_signature, computation_types.Type)
    py_typecheck.check_callable(transform_fn)
    if type_signature.is_federated():
        transformed_member, member_mutated = transform_type_postorder(
            type_signature.member, transform_fn)
        if member_mutated:
            type_signature = computation_types.FederatedType(
                transformed_member, type_signature.placement,
                type_signature.all_equal)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or member_mutated
    elif type_signature.is_sequence():
        transformed_element, element_mutated = transform_type_postorder(
            type_signature.element, transform_fn)
        if element_mutated:
            type_signature = computation_types.SequenceType(
                transformed_element)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or element_mutated
    elif type_signature.is_function():
        # A function type may have no parameter; there is nothing to transform
        # in that case.
        if type_signature.parameter is not None:
            transformed_parameter, parameter_mutated = transform_type_postorder(
                type_signature.parameter, transform_fn)
        else:
            transformed_parameter, parameter_mutated = (None, False)
        transformed_result, result_mutated = transform_type_postorder(
            type_signature.result, transform_fn)
        if parameter_mutated or result_mutated:
            type_signature = computation_types.FunctionType(
                transformed_parameter, transformed_result)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, (type_signature_mutated or parameter_mutated
                                or result_mutated)
    elif type_signature.is_tuple():
        elements = []
        elements_mutated = False
        for name, element_type in anonymous_tuple.iter_elements(type_signature):
            transformed_element, element_mutated = transform_type_postorder(
                element_type, transform_fn)
            elements_mutated = elements_mutated or element_mutated
            elements.append((name, transformed_element))
        if elements_mutated:
            # Rebuild with the same Python container, if the original had one.
            if type_signature.is_tuple_with_py_container():
                container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
                    type_signature)
                type_signature = computation_types.NamedTupleTypeWithPyContainerType(
                    elements, container_type)
            else:
                type_signature = computation_types.NamedTupleType(elements)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or elements_mutated
    elif (type_signature.is_abstract() or type_signature.is_placement()
          or type_signature.is_tensor()):
        # Leaf types have no children; just transform the node itself.
        return transform_fn(type_signature)
    else:
        # Previously this fell through and returned None, which made callers
        # fail later with an unhelpful unpacking error.
        raise TypeError('Unsupported type {}.'.format(
            type(type_signature).__name__))
コード例 #3
0
def type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type):
    """Returns nested structures of tensor dtypes and shapes for a given TFF type.

    The returned dtypes and shapes match those used by `tf.data.Dataset`s to
    indicate the type and shape of their elements. They can be used, e.g., as
    arguments in constructing an iterator over a string handle.

    Args:
      type_spec: A `computation_types.Type` composed only of tuples and
        tensors. Within any one tuple the elements must be either all named or
        all unnamed.

    Returns:
      A pair `(dtypes, shapes)` of parallel nested structures laid out like
      `type_spec`. For a tensor these are its dtype and shape. For a tuple:

      * with a Python container: instances of that container type;
      * without a container and with named elements: `collections.OrderedDict`s
        keyed by element name;
      * without a container and with unnamed (or no) elements: `tuple`s.

    Raises:
      ValueError: If `type_spec` is composed of something other than tuples
        and tensors, or if a tuple mixes named and unnamed elements.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        return (type_spec.dtype, type_spec.shape)
    elif type_spec.is_tuple():
        elements = anonymous_tuple.to_elements(type_spec)
        # The first element decides whether this tuple is treated as named;
        # every other element must agree with it.
        named = bool(elements) and elements[0][0] is not None
        if named:
            output_dtypes = collections.OrderedDict()
            output_shapes = collections.OrderedDict()
        else:
            output_dtypes = []
            output_shapes = []
        for element_name, element_spec in elements:
            if (element_name is not None) != named:
                raise ValueError(
                    'When a sequence appears as a part of a parameter to a section '
                    'of TensorFlow code, in the type signature of elements of that '
                    'sequence all named tuples must have their elements explicitly '
                    'named, and this does not appear to be the case in {}.'
                    .format(type_spec))
            element_dtypes, element_shapes = type_to_tf_dtypes_and_shapes(
                element_spec)
            if named:
                output_dtypes[element_name] = element_dtypes
                output_shapes[element_name] = element_shapes
            else:
                output_dtypes.append(element_dtypes)
                output_shapes.append(element_shapes)
        if type_spec.is_tuple_with_py_container():
            container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
                type_spec)

            def build_py_container(contents):
                # Named tuples and attrs classes take keyword arguments.
                if (py_typecheck.is_named_tuple(container_type)
                        or py_typecheck.is_attrs(container_type)):
                    return container_type(**dict(contents))
                else:
                    return container_type(contents)

            output_dtypes = build_py_container(output_dtypes)
            output_shapes = build_py_container(output_shapes)
        elif not named:
            # Only unnamed results become tuples. Calling tuple() on the
            # OrderedDicts of a named tuple would keep just the element names
            # and silently drop the dtypes and shapes.
            output_dtypes = tuple(output_dtypes)
            output_shapes = tuple(output_shapes)
        return (output_dtypes, output_shapes)
    else:
        raise ValueError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_spec))))