def is_valid_bitwidth_type_for_value_type(
        bitwidth_type: computation_types.Type,
        value_type: computation_types.Type) -> bool:
    """Checks that `bitwidth_type` can describe bitwidths for `value_type`.

    The two types are compared structurally:

    * If both are tensors, `bitwidth_type` must have an integer dtype and
      exactly one element (one bitwidth per value tensor).
    * If both are tuples, they must have the same number of elements, the
      same element names in the same order, and every pair of corresponding
      element types must itself be valid.
    * Any other combination is invalid.

    Args:
      bitwidth_type: The `computation_types.Type` of the bitwidth.
      value_type: The `computation_types.Type` of the value.

    Returns:
      `True` if `bitwidth_type` is valid for `value_type`, otherwise `False`.
    """
    # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
    # `federated_secure_sum` function.
    py_typecheck.check_type(bitwidth_type, computation_types.Type)
    py_typecheck.check_type(value_type, computation_types.Type)

    if value_type.is_tensor() and bitwidth_type.is_tensor():
        # The bitwidth does not have to mirror the value tensor's shape; a
        # single integer per value tensor is what is required.
        holds_integers = bitwidth_type.dtype.is_integer
        return holds_integers and bitwidth_type.shape.num_elements() == 1

    if not (value_type.is_tuple() and bitwidth_type.is_tuple()):
        return False

    bitwidth_elements = list(anonymous_tuple.iter_elements(bitwidth_type))
    value_elements = list(anonymous_tuple.iter_elements(value_type))
    if len(bitwidth_elements) != len(value_elements):
        return False

    # Elements are compared pairwise, by position: names must agree and the
    # nested types must be valid in turn.
    return all(
        b_name == v_name
        and is_valid_bitwidth_type_for_value_type(b_type, v_type)
        for (b_name, b_type), (v_name, v_type) in zip(bitwidth_elements,
                                                      value_elements))
def is_sum_compatible(type_spec: computation_types.Type) -> bool:
    """Determines if `type_spec` is a type that can be added to itself.

    A tensor type is sum-compatible when `is_numeric_dtype` accepts its dtype.
    A tuple type is sum-compatible when every one of its elements is, and a
    federated type is sum-compatible when its member is. Every other kind of
    type (e.g. sequences, functions, abstract types, placements) is not.

    Args:
      type_spec: A `computation_types.Type`.

    Returns:
      `True` iff `type_spec` is sum-compatible, `False` otherwise.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        return is_numeric_dtype(type_spec.dtype)

    if type_spec.is_tuple():
        for _, element_type in anonymous_tuple.iter_elements(type_spec):
            if not is_sum_compatible(element_type):
                return False
        return True

    if type_spec.is_federated():
        return is_sum_compatible(type_spec.member)

    return False
def type_to_tf_structure(type_spec: computation_types.Type):
    """Converts a TFF type into a nested structure of `tf.TensorSpec`s.

    Tensor types become a `tf.TensorSpec`. Tuple types are converted
    element by element and then packaged as follows:

    * without a Python container: an `OrderedDict` if the elements are named,
      otherwise a plain `tuple`;
    * with a Python container: an instance of that container, built from
      keyword arguments for named-tuple / `attrs` containers, from
      `(name, value)` pairs for other containers with named elements, and
      from the bare values when the elements are unnamed.

    Args:
      type_spec: A `computation_types.Type` composed only of tuples and
        tensors. Within any one tuple the elements must be either all named
        or all unnamed, and tuples must not be empty.

    Returns:
      The (possibly nested) structure of `tf.TensorSpec`s that corresponds to
      `type_spec`.

    Raises:
      ValueError: If `type_spec` contains something other than tuples and
        tensors, contains an empty tuple, or contains a tuple whose elements
        are inconsistently named.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        return tf.TensorSpec(type_spec.shape, type_spec.dtype)

    if not type_spec.is_tuple():
        raise ValueError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_spec))))

    elements = anonymous_tuple.to_elements(type_spec)
    if not elements:
        raise ValueError('Empty tuples are unsupported.')

    # Children are converted before the naming check, so an unsupported nested
    # type is reported ahead of a naming inconsistency at this level.
    converted = [(name, type_to_tf_structure(element_type))
                 for name, element_type in elements]

    # The first element decides whether this tuple counts as named; every
    # other element has to agree with it.
    has_names = converted[0][0] is not None
    if any((name is not None) != has_names for name, _ in converted):
        raise ValueError('Tuple elements inconsistently named.')

    if not type_spec.is_tuple_with_py_container():
        if has_names:
            return collections.OrderedDict(converted)
        return tuple(spec for _, spec in converted)

    container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
        type_spec)
    if (py_typecheck.is_named_tuple(container_type)
            or py_typecheck.is_attrs(container_type)):
        return container_type(**dict(converted))
    if has_names:
        return container_type(converted)
    # All names are `None` here (checked above), so only the values are passed.
    return container_type(spec for _, spec in converted)
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
    """Determines if `type_spec` is a structure of integers.

    A tensor type qualifies when its dtype is an integer dtype. A tuple type
    qualifies when all of its elements do, and a federated type qualifies when
    its member does. Any other kind of type does not qualify.

    Args:
      type_spec: A `computation_types.Type`.

    Returns:
      `True` iff `type_spec` is a structure of integers, otherwise `False`.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        tensor_dtype = type_spec.dtype
        py_typecheck.check_type(tensor_dtype, tf.DType)
        return tensor_dtype.is_integer

    if type_spec.is_tuple():
        for _, element_type in anonymous_tuple.iter_elements(type_spec):
            if not is_structure_of_integers(element_type):
                return False
        return True

    if type_spec.is_federated():
        return is_structure_of_integers(type_spec.member)

    return False
def is_average_compatible(type_spec: computation_types.Type) -> bool:
    """Determines if `type_spec` can be averaged.

    A tensor type is average-compatible when its dtype is floating-point or
    complex. A tuple type is average-compatible when all of its elements are,
    and a federated type is average-compatible when its member is. Any other
    kind of type is not.

    Args:
      type_spec: a `computation_types.Type`.

    Returns:
      `True` iff `type_spec` is average-compatible, `False` otherwise.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        tensor_dtype = type_spec.dtype
        return tensor_dtype.is_floating or tensor_dtype.is_complex

    if type_spec.is_tuple():
        for _, element_type in anonymous_tuple.iter_elements(type_spec):
            if not is_average_compatible(element_type):
                return False
        return True

    if type_spec.is_federated():
        return is_average_compatible(type_spec.member)

    return False
# Example #6
# 0
def type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type):
    """Returns nested structures of tensor dtypes and shapes for a given TFF type.

    The returned dtypes and shapes match those used by `tf.data.Dataset`s to
    indicate the type and shape of their elements. They can be used, e.g., as
    arguments in constructing an iterator over a string handle.

    Args:
      type_spec: A `computation_types.Type` composed only of tuples and
        tensors. Within any one tuple the elements must be either all named
        or all unnamed.

    Returns:
      A pair of parallel nested structures with the dtypes and shapes of
      tensors defined in `type_spec`. The layout of the two structures is the
      same as the layout of the nested type defined by `type_spec`:

      * a tensor type yields `(dtype, shape)`;
      * a tuple type with a Python container yields two instances of that
        container;
      * a tuple type without a Python container yields two `OrderedDict`s if
        its elements are named, and two `tuple`s otherwise (including the
        empty tuple).

    Raises:
      ValueError: If `type_spec` is composed of something other than tuples
        and tensors, or if a tuple mixes named and unnamed elements.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        return (type_spec.dtype, type_spec.shape)

    if not type_spec.is_tuple():
        raise ValueError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_spec))))

    naming_error = (
        'When a sequence appears as a part of a parameter to a section '
        'of TensorFlow code, in the type signature of elements of that '
        'sequence all named tuples must have their elements explicitly '
        'named, and this does not appear to be the case in {}.')

    elements = anonymous_tuple.to_elements(type_spec)
    # The first element decides whether the tuple is treated as named; an
    # empty tuple is treated as unnamed.
    expect_names = bool(elements) and elements[0][0] is not None
    if expect_names:
        output_dtypes = collections.OrderedDict()
        output_shapes = collections.OrderedDict()
    else:
        output_dtypes = []
        output_shapes = []

    for element_name, element_spec in elements:
        # Every element must agree with the first one about being named.
        if (element_name is not None) != expect_names:
            raise ValueError(naming_error.format(type_spec))
        element_dtypes, element_shapes = type_to_tf_dtypes_and_shapes(
            element_spec)
        if expect_names:
            output_dtypes[element_name] = element_dtypes
            output_shapes[element_name] = element_shapes
        else:
            output_dtypes.append(element_dtypes)
            output_shapes.append(element_shapes)

    container_type = type_spec.python_container
    if container_type is not None:

        def build_py_container(collected):
            # Named-tuple and `attrs` classes take their fields as keyword
            # arguments; other containers are built from the collection.
            if (py_typecheck.is_named_tuple(container_type)
                    or py_typecheck.is_attrs(container_type)):
                return container_type(**dict(collected))
            return container_type(collected)

        output_dtypes = build_py_container(output_dtypes)
        output_shapes = build_py_container(output_shapes)
    elif not expect_names:
        output_dtypes = tuple(output_dtypes)
        output_shapes = tuple(output_shapes)
    # Named elements without a Python container stay as `OrderedDict`s.
    # Calling `tuple()` on them would keep only the element names and silently
    # drop the dtypes and shapes.
    return (output_dtypes, output_shapes)
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> pb.Computation:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`.

  Note: If `operand_type` is a `computation_types.NamedTupleType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode. For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Returns:
    A `pb.Computation` holding the serialized TensorFlow graph, its parameter
    and result bindings, and the serialized function type signature.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or `operator`
      is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.NamedTupleType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)

    # These branches must form one chain. If the `else` hangs off a separate
    # `is not None` check, an operand type that is neither a tensor nor a
    # tuple leaves `result_value` unbound and fails below with an
    # `UnboundLocalError` instead of this `TypeError`.
    if operand_type is not None and operand_type.is_tensor():
      result_value = operator(operand_1_value, operand_2_value)
    elif operand_type is not None and operand_type.is_tuple():
      # Tuples are handled pointwise, element by element.
      result_value = anonymous_tuple.map_structure(operator, operand_1_value,
                                                   operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.NamedTupleType((operand_type, operand_type)),
      result_type)
  parameter_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
 def _concretize_abstract_types(
         abstract_type_spec: computation_types.Type,
         concrete_type_spec: computation_types.Type
 ) -> computation_types.Type:
     """Recursive helper function to construct concrete type spec.

     Walks `abstract_type_spec` and `concrete_type_spec` in parallel and
     returns a copy of the abstract type in which every abstract type is
     replaced by a concrete one. Bindings are recorded in
     `bound_abstract_types`, keyed by `str(label)`: the first occurrence of a
     label binds it to the concrete type found at that position, and later
     occurrences reuse that binding.

     `bound_abstract_types` and `type_error_string` are not defined here; they
     are read from the enclosing scope.

     Args:
       abstract_type_spec: The `computation_types.Type` that may contain
         abstract types.
       concrete_type_spec: The `computation_types.Type` supplying the concrete
         types; it must have the same structure as `abstract_type_spec`.

     Returns:
       The concretized `computation_types.Type`.

     Raises:
       TypeError: If the structures of the two types do not match, or if
         `abstract_type_spec` is of an unexpected kind.
     """
     if abstract_type_spec.is_abstract():
         label = str(abstract_type_spec.label)
         bound_type = bound_abstract_types.get(label)
         # Compare against `None` rather than relying on truthiness: a bound
         # type that evaluates as falsy (an empty tuple type could, given
         # that tuple types support `len()`) must still count as bound.
         if bound_type is not None:
             return bound_type
         bound_abstract_types[label] = concrete_type_spec
         return concrete_type_spec
     elif abstract_type_spec.is_tensor():
         # Tensor types are returned as-is; `concrete_type_spec` is not
         # inspected in this branch.
         return abstract_type_spec
     elif abstract_type_spec.is_tuple():
         if not concrete_type_spec.is_tuple():
             raise TypeError(type_error_string)
         abstract_elements = anonymous_tuple.to_elements(abstract_type_spec)
         concrete_elements = anonymous_tuple.to_elements(concrete_type_spec)
         if len(abstract_elements) != len(concrete_elements):
             raise TypeError(type_error_string)
         concretized_tuple_elements = []
         for (abstract_name, abstract_element), (
                 concrete_name, concrete_element) in zip(
                     abstract_elements, concrete_elements):
             if abstract_name != concrete_name:
                 raise TypeError(type_error_string)
             concretized_tuple_elements.append(
                 (abstract_name,
                  _concretize_abstract_types(abstract_element,
                                             concrete_element)))
         return computation_types.NamedTupleType(concretized_tuple_elements)
     elif abstract_type_spec.is_sequence():
         if not concrete_type_spec.is_sequence():
             raise TypeError(type_error_string)
         return computation_types.SequenceType(
             _concretize_abstract_types(abstract_type_spec.element,
                                        concrete_type_spec.element))
     elif abstract_type_spec.is_function():
         if not concrete_type_spec.is_function():
             raise TypeError(type_error_string)
         if abstract_type_spec.parameter is None:
             if concrete_type_spec.parameter is not None:
                 # This must be raised. Returning the exception object would
                 # hand it back to the caller as if it were the concretized
                 # type.
                 raise TypeError(type_error_string)
             concretized_param = None
         else:
             concretized_param = _concretize_abstract_types(
                 abstract_type_spec.parameter, concrete_type_spec.parameter)
         concretized_result = _concretize_abstract_types(
             abstract_type_spec.result, concrete_type_spec.result)
         return computation_types.FunctionType(concretized_param,
                                               concretized_result)
     elif abstract_type_spec.is_placement():
         if not concrete_type_spec.is_placement():
             raise TypeError(type_error_string)
         return abstract_type_spec
     elif abstract_type_spec.is_federated():
         if not concrete_type_spec.is_federated():
             raise TypeError(type_error_string)
         new_member = _concretize_abstract_types(abstract_type_spec.member,
                                                 concrete_type_spec.member)
         # Placement and `all_equal` are taken from the abstract type; only
         # the member is concretized.
         return computation_types.FederatedType(
             new_member, abstract_type_spec.placement,
             abstract_type_spec.all_equal)
     else:
         raise TypeError(
             'Unexpected abstract typespec {}.'.format(abstract_type_spec))
# Example #9
# 0
def create_binary_operator_with_upcast(
        type_signature: computation_types.Type,
        operator: Callable[[Any, Any], Any]) -> pb.Computation:
    """Creates TF computation upcasting its argument and applying `operator`.

    The computation takes a two-element tuple. When the two element types are
    not equivalent, the second element is broadcast into the structure of the
    first before `operator` is applied (directly for a tensor first element,
    pointwise for a tuple first element).

    Args:
      type_signature: Value convertible to `computation_types.NamedTupleType`,
        with two elements, both of the same type or the second able to be
        upcast to the first, as explained in
        `apply_binary_operator_with_upcast`, and both containing only tuples
        and tensors in their type tree.
      operator: Callable defining the operator.

    Returns:
      A `pb.Computation` encapsulating a function which upcasts the second
      element of its argument and applies the binary operator.

    Raises:
      TypeError: If `type_signature` is not a two-element tuple type, contains
        a sequence type, or its first element is neither a tensor nor a tuple.
    """
    py_typecheck.check_callable(operator)
    type_signature = computation_types.to_type(type_signature)
    type_analysis.check_tensorflow_compatible_type(type_signature)

    is_pair = type_signature.is_tuple() and len(type_signature) == 2
    if not is_pair:
        raise TypeError(
            'To apply a binary operator, we must by definition have an '
            'argument which is a `NamedTupleType` with 2 elements; '
            'asked to create a binary operator for type: {t}'.format(
                t=type_signature))

    has_sequence = type_analysis.contains(type_signature,
                                          lambda t: t.is_sequence())
    if has_sequence:
        raise TypeError('Applying binary operators in TensorFlow is only '
                        'supported on Tensors and NamedTupleTypes; you '
                        'passed {t} which contains a SequenceType.'.format(
                            t=type_signature))

    def _broadcast_into(value, target_type):
        """Broadcasts `value` into every tensor slot of `target_type`."""
        if target_type.is_tuple():
            packed = []
            for name, inner_type in anonymous_tuple.iter_elements(target_type):
                packed.append((name, _broadcast_into(value, inner_type)))
            return anonymous_tuple.AnonymousTuple(packed)
        if target_type.is_tensor():
            return tf.broadcast_to(value, target_type.shape)
        # Any other kind of type yields `None`.
        return None

    left_type = type_signature[0]
    right_type = type_signature[1]

    with tf.Graph().as_default() as graph:
        first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', left_type, graph)
        operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
            'y', right_type, graph)

        # Only reshape the second operand when its type differs from the
        # first operand's type.
        second_arg = (operand_2_value
                      if left_type.is_equivalent_to(right_type) else
                      _broadcast_into(operand_2_value, left_type))

        if left_type.is_tensor():
            result_value = operator(first_arg, second_arg)
        elif left_type.is_tuple():
            result_value = anonymous_tuple.map_structure(
                operator, first_arg, second_arg)
        else:
            raise TypeError(
                'Encountered unexpected type {t}; can only handle Tensor '
                'and NamedTupleTypes.'.format(t=left_type))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

    function_type = computation_types.FunctionType(type_signature,
                                                   result_type)
    operand_bindings = pb.TensorFlow.NamedTupleBinding(
        element=[operand_1_binding, operand_2_binding])
    parameter_binding = pb.TensorFlow.Binding(tuple=operand_bindings)
    packed_graph_def = serialization_utils.pack_graph_def(
        graph.as_graph_def())
    tensorflow = pb.TensorFlow(graph_def=packed_graph_def,
                               parameter=parameter_binding,
                               result=result_binding)
    return pb.Computation(
        type=type_serialization.serialize_type(function_type),
        tensorflow=tensorflow)
# Example #10
# 0
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
    """Walks type tree of `type_signature` postorder, calling `transform_fn`.

    Children are transformed first. If any child changed, the parent is
    rebuilt from the transformed children, and only then is `transform_fn`
    applied to the (possibly rebuilt) parent.

    Args:
      type_signature: Instance of `computation_types.Type` to transform
        recursively.
      transform_fn: Transformation function to apply to each node in the type
        tree of `type_signature`. Must be callable, and must return a pair of
        the (possibly replaced) type and a bool saying whether it changed.

    Returns:
      A pair `(type, mutated)`: a possibly transformed version of
      `type_signature`, with each node in its tree the result of applying
      `transform_fn` to the corresponding node in `type_signature`, and a bool
      that is true if `transform_fn` reported a change for any node.

    Raises:
      TypeError: If the types don't match the specification above, or if
        `type_signature` is of a kind this function does not handle.
    """
    py_typecheck.check_type(type_signature, computation_types.Type)
    py_typecheck.check_callable(transform_fn)
    if type_signature.is_federated():
        transformed_member, member_mutated = transform_type_postorder(
            type_signature.member, transform_fn)
        if member_mutated:
            type_signature = computation_types.FederatedType(
                transformed_member, type_signature.placement,
                type_signature.all_equal)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or member_mutated
    elif type_signature.is_sequence():
        transformed_element, element_mutated = transform_type_postorder(
            type_signature.element, transform_fn)
        if element_mutated:
            type_signature = computation_types.SequenceType(
                transformed_element)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or element_mutated
    elif type_signature.is_function():
        # A function type may have no parameter; there is then nothing to
        # transform on that side.
        if type_signature.parameter is not None:
            transformed_parameter, parameter_mutated = transform_type_postorder(
                type_signature.parameter, transform_fn)
        else:
            transformed_parameter, parameter_mutated = (None, False)
        transformed_result, result_mutated = transform_type_postorder(
            type_signature.result, transform_fn)
        if parameter_mutated or result_mutated:
            type_signature = computation_types.FunctionType(
                transformed_parameter, transformed_result)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, (type_signature_mutated or parameter_mutated
                                or result_mutated)
    elif type_signature.is_tuple():
        elements = []
        elements_mutated = False
        for element_name, element_type in anonymous_tuple.iter_elements(
                type_signature):
            transformed_element, element_mutated = transform_type_postorder(
                element_type, transform_fn)
            elements_mutated = elements_mutated or element_mutated
            elements.append((element_name, transformed_element))
        if elements_mutated:
            # Rebuild with the same flavor of tuple type, keeping the Python
            # container when the original had one.
            if type_signature.is_tuple_with_py_container():
                container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
                    type_signature)
                type_signature = computation_types.NamedTupleTypeWithPyContainerType(
                    elements, container_type)
            else:
                type_signature = computation_types.NamedTupleType(elements)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or elements_mutated
    elif type_signature.is_abstract() or type_signature.is_placement(
    ) or type_signature.is_tensor():
        # Leaf kinds: no children to visit.
        return transform_fn(type_signature)
    else:
        # Fail loudly rather than fall off the end and return `None`, which
        # callers would then fail to unpack into `(type, mutated)`.
        raise TypeError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_signature))))
 def _is_compatible(t: computation_types.Type) -> bool:
     """Checks that `t` contains only tuple, tensor and federated types."""

     def _is_allowed_kind(node):
         return node.is_tuple() or node.is_tensor() or node.is_federated()

     return type_analysis.contains_only(t, _is_allowed_kind)