Example #1
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Whether or not `bitwidth_type` is a valid bitwidth type for `value_type`."""

  # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  # `federated_secure_sum` function.
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if bitwidth_type.is_tensor():
    # This condition applies whether `value_type` is a tensor or a structure,
    # since the same integer bitwidth can be used for all values in the
    # structure.
    return bitwidth_type.dtype.is_integer and (
        bitwidth_type.shape.num_elements() == 1)
  elif value_type.is_struct() and bitwidth_type.is_struct():
    bitwidth_name_and_types = list(structure.iter_elements(bitwidth_type))
    value_name_and_types = list(structure.iter_elements(value_type))
    if len(bitwidth_name_and_types) != len(value_name_and_types):
      return False
    for (inner_bitwidth_name,
         inner_bitwidth_type), (inner_value_name, inner_value_type) in zip(
             bitwidth_name_and_types, value_name_and_types):
      if inner_bitwidth_name != inner_value_name:
        return False
      if not is_valid_bitwidth_type_for_value_type(inner_bitwidth_type,
                                                   inner_value_type):
        return False
    return True
  else:
    return False
Example #2
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Whether or not `bitwidth_type` is a valid bitwidth type for `value_type`."""

  # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  # `federated_secure_sum` function.
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if value_type.is_tensor() and bitwidth_type.is_tensor():
    # Here, `value_type` refers to a tensor. Rather than check that
    # `bitwidth_type` is exactly the same, we check that it is a single integer,
    # since we want a single bitwidth integer per tensor.
    return bitwidth_type.dtype.is_integer and (
        bitwidth_type.shape.num_elements() == 1)
  elif value_type.is_struct() and bitwidth_type.is_struct():
    bitwidth_name_and_types = list(structure.iter_elements(bitwidth_type))
    value_name_and_types = list(structure.iter_elements(value_type))
    if len(bitwidth_name_and_types) != len(value_name_and_types):
      return False
    for (inner_bitwidth_name,
         inner_bitwidth_type), (inner_value_name, inner_value_type) in zip(
             bitwidth_name_and_types, value_name_and_types):
      if inner_bitwidth_name != inner_value_name:
        return False
      if not is_valid_bitwidth_type_for_value_type(inner_bitwidth_type,
                                                   inner_value_type):
        return False
    return True
  else:
    return False
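A minimal usage sketch of the check above; the names and shapes are hypothetical, and the module's `computation_types` and `tf` imports are assumed to be in scope.

value_type = computation_types.StructType([
    ('a', computation_types.TensorType(tf.int32, [10])),
    ('b', computation_types.TensorType(tf.int32)),
])
bitwidth_type = computation_types.StructType([
    ('a', computation_types.TensorType(tf.int32)),  # one scalar integer bitwidth per tensor
    ('b', computation_types.TensorType(tf.int32)),
])
assert is_valid_bitwidth_type_for_value_type(bitwidth_type, value_type)
# Mismatched element names or structure (or a non-integer bitwidth) fail the check.
assert not is_valid_bitwidth_type_for_value_type(
    bitwidth_type,
    computation_types.StructType([('c', computation_types.TensorType(tf.int32))]))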
Example #3
def is_sum_compatible(type_spec: computation_types.Type) -> bool:
    """Determines if `type_spec` is a type that can be added to itself.

  Types that are sum-compatible are composed of scalars of numeric types,
  possibly packaged into nested named tuples, and possibly federated. Types
  that are sum-incompatible include sequences, functions, abstract types,
  and placements.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is sum-compatible, `False` otherwise.
  """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        return is_numeric_dtype(
            type_spec.dtype) and type_spec.shape.is_fully_defined()
    elif type_spec.is_struct():
        return all(
            is_sum_compatible(v)
            for _, v in structure.iter_elements(type_spec))
    elif type_spec.is_federated():
        return is_sum_compatible(type_spec.member)
    else:
        return False
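A brief sketch of what passes and fails the check, assuming the module's imports are in scope; `tf.string` stands in for any non-numeric dtype.

int_vector = computation_types.TensorType(tf.int32, [3])
assert is_sum_compatible(int_vector)
# Structures (and federated types) recurse into their elements.
assert is_sum_compatible(
    computation_types.StructType([('x', int_vector), ('y', int_vector)]))
# Non-numeric tensors, sequences, functions, etc. are not sum-compatible.
assert not is_sum_compatible(computation_types.TensorType(tf.string))
assert not is_sum_compatible(computation_types.SequenceType(int_vector))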
Example #4
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
    """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_signature`. NOTE: if `T` contains `computation_types.StructType`s
  without an associated container type, they will be given the container type
  `tuple` by this function.

  Args:
    type_signature: A `computation_types.Type` to use as the parameter type and
      result type of the identity function.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings.
  """
    type_analysis.check_tensorflow_compatible_type(type_signature)
    parameter_type = type_signature
    if parameter_type is None:
        raise TypeError('TensorFlow identity cannot be created for NoneType.')

    # TF relies on feeds not-identical to fetches in certain circumstances.
    if type_signature.is_tensor() or type_signature.is_sequence():
        identity_fn = tf.identity
    elif type_signature.is_struct():
        identity_fn = functools.partial(structure.map_structure, tf.identity)
    else:
        raise NotImplementedError(
            f'TensorFlow identity cannot be created for type {type_signature}')

    return create_computation_for_py_fn(identity_fn, parameter_type)
Example #5
 def ingest_value(
     self, value: Any, type_signature: computation_types.Type
 ) -> executor_value_base.ExecutorValue:
   if type_signature is not None:
     if type_signature.is_federated():
       self._check_strategy_compatible_with_placement(type_signature.placement)
     elif type_signature.is_function() and type_signature.result.is_federated(
     ):
       self._check_strategy_compatible_with_placement(
           type_signature.result.placement)
   return FederatedResolvingStrategyValue(value, type_signature)
Example #6
def _to_sequence_internal_rep(
        *, value: Any, type_spec: computation_types.Type) -> tf.data.Dataset:
    """Ingests `value`, converting to an eager dataset."""
    if isinstance(value, list):
        value = tensorflow_utils.make_data_set_from_elements(
            None, value, type_spec.element)
    py_typecheck.check_type(value,
                            type_conversions.TF_DATASET_REPRESENTATION_TYPES)
    element_type = computation_types.to_type(value.element_spec)
    value_type = computation_types.SequenceType(element_type)
    type_spec.check_assignable_from(value_type)
    return value
Example #7
def _ensure_deserialized_types_compatible(
    previous_type: Optional[computation_types.Type],
    next_type: computation_types.Type) -> computation_types.Type:
  """Ensures one of `previous_type` or `next_type` is assignable to the other.

  Returns the type which is assignable from the other.

  Args:
    previous_type: Instance of `computation_types.Type` or `None`.
    next_type: Instance of `computation_types.Type`.

  Returns:
    The supertype of `previous_type` and `next_type`.

  Raises:
    TypeError: If neither type is assignable from the other.
  """
  if previous_type is None:
    return next_type
  else:
    if next_type.is_assignable_from(previous_type):
      return next_type
    elif previous_type.is_assignable_from(next_type):
      return previous_type
    raise TypeError('Type mismatch checking member assignability under a '
                    'federated value. Deserialized type {} is incompatible '
                    'with previously deserialized {}.'.format(
                        next_type, previous_type))
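A hedged sketch of the reconciliation on tensor types (shapes are hypothetical): the first call seeds the type, and each later call keeps whichever of the two types is assignable from the other.

t_fixed = computation_types.TensorType(tf.float32, [3])
t_free = computation_types.TensorType(tf.float32, [None])
t = _ensure_deserialized_types_compatible(None, t_fixed)  # -> float32[3]
t = _ensure_deserialized_types_compatible(t, t_free)      # -> float32[?], the more general type
# _ensure_deserialized_types_compatible(t, computation_types.TensorType(tf.int32))
# would raise TypeError, since neither type is assignable from the other.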
Example #8
def type_to_tf_structure(type_spec: computation_types.Type):
    """Returns nested `tf.data.experimental.Structure` for a given TFF type.

  Args:
    type_spec: A `computation_types.Type`, the type specification must be
      composed of only named tuples and tensors. In all named tuples that appear
      in the type spec, all the elements must be named.

  Returns:
    An instance of `tf.data.experimental.Structure`, possibly nested, that
    corresponds to `type_spec`.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, or if any of the elements in named tuples are unnamed.
  """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        return tf.TensorSpec(type_spec.shape, type_spec.dtype)
    elif type_spec.is_tuple():
        elements = anonymous_tuple.to_elements(type_spec)
        if not elements:
            raise ValueError('Empty tuples are unsupported.')
        element_outputs = [(k, type_to_tf_structure(v)) for k, v in elements]
        named = element_outputs[0][0] is not None
        if not all((e[0] is not None) == named for e in element_outputs):
            raise ValueError('Tuple elements inconsistently named.')
        if not type_spec.is_tuple_with_py_container():
            if named:
                output = collections.OrderedDict(element_outputs)
            else:
                output = tuple(v for _, v in element_outputs)
        else:
            container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
                type_spec)
            if (py_typecheck.is_named_tuple(container_type)
                    or py_typecheck.is_attrs(container_type)):
                output = container_type(**dict(element_outputs))
            elif named:
                output = container_type(element_outputs)
            else:
                output = container_type(e if e[0] is not None else e[1]
                                        for e in element_outputs)
        return output
    else:
        raise ValueError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_spec))))
Example #9
def _visit_preorder(
    type_spec: computation_types.Type,
    fn: Callable[[computation_types.Type, T], T],
    context: T,
):
    context = fn(type_spec, context)
    for child_type in type_spec.children():
        _visit_preorder(child_type, fn, context)
Example #10
def contains(type_signature: computation_types.Type,
             predicate: _TypePredicate) -> bool:
  """Checks if `type_signature` contains any types that pass `predicate`."""
  if predicate(type_signature):
    return True
  for child in type_signature.children():
    if contains(child, predicate):
      return True
  return False
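A minimal usage sketch, assuming the module's imports are in scope; the structure is hypothetical.

struct_type = computation_types.StructType([
    ('weights', computation_types.TensorType(tf.float32, [10])),
    ('rounds', computation_types.TensorType(tf.int32)),
])
assert contains(struct_type, lambda t: t.is_tensor())
assert not contains(struct_type, lambda t: t.is_federated())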
Example #11
def _check_type_is_fn(
    target: computation_types.Type,
    name: str,
    err_fn: Callable[
        [str], Exception] = transformations.CanonicalFormCompilationError,
):
    if not target.is_function():
        raise err_fn(f'Expected {name} to be a function, but {name} had type '
                     f'{target}.')
Example #12
 def transform_to_tff_known_type(
     type_spec: computation_types.Type) -> Tuple[computation_types.Type, bool]:
   """Transforms `StructType` to `StructWithPythonType`."""
   if type_spec.is_struct() and not type_spec.is_struct_with_python():
     field_is_named = tuple(
         name is not None for name, _ in structure.iter_elements(type_spec))
     has_names = any(field_is_named)
     is_all_named = all(field_is_named)
     if is_all_named:
       return computation_types.StructWithPythonType(
           elements=structure.iter_elements(type_spec),
           container_type=collections.OrderedDict), True
     elif not has_names:
       return computation_types.StructWithPythonType(
           elements=structure.iter_elements(type_spec),
           container_type=tuple), True
     else:
       raise TypeError('Cannot represent TFF type in TF because it contains '
                       f'partially named structures. Type: {type_spec}')
   return type_spec, False
Example #13
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a structure of integers.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of integers, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    py_typecheck.check_type(type_spec.dtype, tf.DType)
    return type_spec.dtype.is_integer
  elif type_spec.is_struct():
    return all(
        is_structure_of_integers(v)
        for _, v in structure.iter_elements(type_spec))
  elif type_spec.is_federated():
    return is_structure_of_integers(type_spec.member)
  else:
    return False
Example #14
def count(type_signature: computation_types.Type,
          predicate: _TypePredicate) -> int:
  """Returns the number of types in `type_signature` matching `predicate`.

  Args:
    type_signature: A tree of `computation_types.Type`s to count.
    predicate: A Python function that takes a type as a parameter and returns a
      boolean value.
  """
  counter = 1 if predicate(type_signature) else 0
  counter += sum(count(child, predicate) for child in type_signature.children())
  return counter
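A small sketch of counting nodes in a nested type (the structure is hypothetical, and the module's imports are assumed).

nested = computation_types.StructType([
    ('a', computation_types.TensorType(tf.int32)),
    ('b', computation_types.StructType([('c', computation_types.TensorType(tf.float32))])),
])
assert count(nested, lambda t: t.is_tensor()) == 2
assert count(nested, lambda t: t.is_struct()) == 2  # the outer struct and the nested one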
Example #15
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_signature`. NOTE: if `T` contains `computation_types.StructType`s
  without an associated container type, they will be given the container type
  `tuple` by this function.

  Args:
    type_signature: A `computation_types.Type` to use as the parameter type and
      result type of the identity function.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  parameter_type = type_signature
  if parameter_type is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')

  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', parameter_type, graph)
    # TF relies on feeds not-identical to fetches in certain circumstances.
    if type_signature.is_tensor():
      parameter_value = tf.identity(parameter_value)
    elif type_signature.is_struct():
      parameter_value = structure.map_structure(tf.identity, parameter_value)
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        parameter_value, graph)

  type_signature = computation_types.FunctionType(parameter_type, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
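A hedged usage sketch, assuming `ProtoAndType` unpacks as a `(proto, type)` pair as elsewhere in this module.

_, identity_type = create_identity(computation_types.TensorType(tf.float32, [2]))
# identity_type is the function type `(float32[2] -> float32[2])`.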
Example #16
def is_binary_op_with_upcast_compatible_pair(
    possibly_nested_type: Optional[computation_types.Type],
    type_to_upcast: computation_types.Type) -> bool:
  """Checks unambiguity in applying `type_to_upcast` to `possibly_nested_type`.

  That is, checks that either these types are equivalent and contain only
  tuples and tensors, or that `possibly_nested_type` is a (possibly nested)
  structure containing only tensors whose `dtype` matches `type_to_upcast` at
  the leaves, where `type_to_upcast` must be a scalar tensor type. Notice that
  this relationship is not symmetric, since binary operators need not respect
  this symmetry in general. For example, it makes perfect sense to divide a
  nested structure of tensors by a scalar, but not the other way around.

  Args:
    possibly_nested_type: A `computation_types.Type`, or `None`.
    type_to_upcast: A `computation_types.Type`, or `None`.

  Returns:
    Boolean indicating whether `type_to_upcast` can be upcast to
    `possibly_nested_type` in the manner described above.
  """
  if possibly_nested_type is not None:
    py_typecheck.check_type(possibly_nested_type, computation_types.Type)
  if type_to_upcast is not None:
    py_typecheck.check_type(type_to_upcast, computation_types.Type)
  if not (is_generic_op_compatible_type(possibly_nested_type) and
          is_generic_op_compatible_type(type_to_upcast)):
    return False
  if possibly_nested_type is None:
    return type_to_upcast is None
  if possibly_nested_type.is_equivalent_to(type_to_upcast):
    return True
  if not (type_to_upcast.is_tensor() and type_to_upcast.shape == tf.TensorShape(
      ())):
    return False

  types_are_ok = [True]

  only_allowed_dtype = type_to_upcast.dtype

  def _check_tensor_types(type_spec):
    if type_spec.is_tensor() and type_spec.dtype != only_allowed_dtype:
      types_are_ok[0] = False
    return type_spec, False

  type_transformations.transform_type_postorder(possibly_nested_type,
                                                _check_tensor_types)

  return types_are_ok[0]
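A hedged sketch of the asymmetry described above; the types are hypothetical and the module's imports are assumed.

nested = computation_types.StructType([
    ('a', computation_types.TensorType(tf.float32, [5])),
    ('b', computation_types.TensorType(tf.float32)),
])
scalar = computation_types.TensorType(tf.float32)
assert is_binary_op_with_upcast_compatible_pair(nested, scalar)
# A scalar of a different dtype cannot be upcast into this structure, and the
# relationship does not hold with the arguments swapped.
assert not is_binary_op_with_upcast_compatible_pair(
    nested, computation_types.TensorType(tf.int32))
assert not is_binary_op_with_upcast_compatible_pair(scalar, nested)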
Example #17
def is_average_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` can be averaged.

  Types that are average-compatible are composed of numeric tensor types,
  either floating-point or complex, possibly packaged into nested named tuples,
  and possibly federated.

  Args:
    type_spec: a `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return type_spec.dtype.is_floating or type_spec.dtype.is_complex
  elif type_spec.is_struct():
    return all(
        is_average_compatible(v) for _, v in structure.iter_elements(type_spec))
  elif type_spec.is_federated():
    return is_average_compatible(type_spec.member)
  else:
    return False
Example #18
def _to_tensor_internal_rep(*, value: Any,
                            type_spec: computation_types.Type) -> tf.Tensor:
    """Normalizes tensor-like value to a tf.Tensor."""
    if not tf.is_tensor(value):
        value = tf.convert_to_tensor(value, dtype=type_spec.dtype)
    elif hasattr(value, 'read_value'):
        # a tf.Variable-like result, get a proper tensor.
        value = value.read_value()
    value_type = (computation_types.TensorType(value.dtype.base_dtype,
                                               value.shape))
    if not type_spec.is_assignable_from(value_type):
        raise TypeError(
            'The apparent type {} of a tensor {} does not match the expected '
            'type {}.'.format(value_type, value, type_spec))
    return value
Example #19
def is_structure_of_floats(type_spec: computation_types.Type) -> bool:
    """Determines if `type_spec` is a structure of floats.

  Note that an empty `computation_types.StructType` will return `True`, as it
  does not contain any non-floating types.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of floats, otherwise `False`.
  """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        py_typecheck.check_type(type_spec.dtype, tf.DType)
        return type_spec.dtype.is_floating
    elif type_spec.is_struct():
        return all(
            is_structure_of_floats(v)
            for _, v in structure.iter_elements(type_spec))
    elif type_spec.is_federated():
        return is_structure_of_floats(type_spec.member)
    else:
        return False
Example #20
def check_type(value: Any, type_spec: computation_types.Type):
  """Checks whether `val` is of TFF type `type_spec`.

  Args:
    value: The object to check.
    type_spec: A `computation_types.Type`, the type that `value` is checked
      against.

  Raises:
    TypeError: If the inferred type of `value` is not assignable to
      `type_spec`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  value_type = type_conversions.infer_type(value)
  if not type_spec.is_assignable_from(value_type):
    raise TypeError(
        'Expected TFF type {}, which is not assignable from {}.'.format(
            type_spec, value_type))
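A minimal sketch, assuming `type_conversions.infer_type` maps a Python `int` to an `int32` tensor type (as it does in this codebase).

check_type(10, computation_types.TensorType(tf.int32))  # passes silently
# check_type(10.0, computation_types.TensorType(tf.int32))  # would raise TypeError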
Example #21
def visit_preorder(type_signature: computation_types.Type,
                   fn: Callable[[computation_types.Type, T], T], context: T):
  """Recursively calls `fn` on the possibly nested structure `type_signature`.

  Walks the tree in a preorder manner. Updates `context` on the way down with
  the appropriate information, as defined in `fn`.

  Args:
    type_signature: A `computation_types.Type`.
    fn: A function to apply to each of the constituent elements of
      `type_signature` with the argument `context`. Must return an updated
      version of `context` which incorporates the information we'd like to track
      as we move down the type tree.
    context: Initial state of information to be passed down the tree.
  """
  context = fn(type_signature, context)
  for child_type in type_signature.children():
    visit_preorder(child_type, fn, context)
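Because the updated context is only threaded down each branch (the recursive calls discard their return values), it naturally carries per-path information such as nesting depth. A hedged sketch, with a hypothetical printer as `fn`:

def _print_with_depth(type_spec: computation_types.Type, depth: int) -> int:
  # Print the node indented by its depth, then pass depth + 1 to its children.
  print('  ' * depth + str(type_spec))
  return depth + 1

visit_preorder(
    computation_types.StructType([
        ('a', computation_types.TensorType(tf.int32)),
        ('b', computation_types.StructType([('c', computation_types.TensorType(tf.float32))])),
    ]),
    _print_with_depth,
    0)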
Example #22
def reconcile_value_type_with_type_spec(
    value_type: computation_types.Type,
    type_spec: Optional[computation_types.Type]) -> computation_types.Type:
  """Reconciles a pair of types.

  Args:
    value_type: An instance of `tff.Type`.
    type_spec: An instance of `tff.Type`, or `None`.

  Returns:
    Either `value_type` if `type_spec` is `None`, or `type_spec` if `type_spec`
    is not `None` and equivalent to `value_type`.

  Raises:
    TypeError: If arguments are of incompatible types.
  """
  py_typecheck.check_type(value_type, computation_types.Type)
  if type_spec is not None:
    py_typecheck.check_type(type_spec, computation_types.Type)
    if not value_type.is_equivalent_to(type_spec):
      raise TypeError('Expected a value of type {}, found {}.'.format(
          type_spec, value_type))
  return type_spec if type_spec is not None else value_type
Example #23
def type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type):
    """Returns nested structures of tensor dtypes and shapes for a given TFF type.

  The returned dtypes and shapes match those used by `tf.data.Dataset`s to
  indicate the type and shape of their elements. They can be used, e.g., as
  arguments in constructing an iterator over a string handle.

  Args:
    type_spec: A `computation_types.Type`, the type specification must be
      composed of only named tuples and tensors. In all named tuples that appear
      in the type spec, all the elements must be named.

  Returns:
    A pair of parallel nested structures with the dtypes and shapes of tensors
    defined in `type_spec`. The layout of the two structures returned is the
    same as the layout of the nested type defined by `type_spec`. Named tuples
    are represented as dictionaries.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, or if any of the elements in named tuples are unnamed.
  """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        return (type_spec.dtype, type_spec.shape)
    elif type_spec.is_struct():
        elements = structure.to_elements(type_spec)
        if not elements:
            output_dtypes = []
            output_shapes = []
        elif elements[0][0] is not None:
            output_dtypes = collections.OrderedDict()
            output_shapes = collections.OrderedDict()
            for e in elements:
                element_name = e[0]
                element_spec = e[1]
                if element_name is None:
                    raise ValueError(
                        'When a sequence appears as a part of a parameter to a section '
                        'of TensorFlow code, in the type signature of elements of that '
                        'sequence all named tuples must have their elements explicitly '
                        'named, and this does not appear to be the case in {}.'
                        .format(type_spec))
                element_output = type_to_tf_dtypes_and_shapes(element_spec)
                output_dtypes[element_name] = element_output[0]
                output_shapes[element_name] = element_output[1]
        else:
            output_dtypes = []
            output_shapes = []
            for e in elements:
                element_name = e[0]
                element_spec = e[1]
                if element_name is not None:
                    raise ValueError(
                        'When a sequence appears as a part of a parameter to a section '
                        'of TensorFlow code, in the type signature of elements of that '
                        'sequence all named tuples must have their elements explicitly '
                        'named, and this does not appear to be the case in {}.'
                        .format(type_spec))
                element_output = type_to_tf_dtypes_and_shapes(element_spec)
                output_dtypes.append(element_output[0])
                output_shapes.append(element_output[1])
        if type_spec.python_container is not None:
            container_type = type_spec.python_container

            def build_py_container(elements):
                if (py_typecheck.is_named_tuple(container_type)
                        or py_typecheck.is_attrs(container_type)):
                    return container_type(**dict(elements))
                else:
                    return container_type(elements)

            output_dtypes = build_py_container(output_dtypes)
            output_shapes = build_py_container(output_shapes)
        else:
            output_dtypes = tuple(output_dtypes)
            output_shapes = tuple(output_shapes)
        return (output_dtypes, output_shapes)
    else:
        raise ValueError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_spec))))
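A hedged sketch of the output for a fully named structure, using `StructWithPythonType` so that `python_container` is set; the field names and shapes are hypothetical.

spec = computation_types.StructWithPythonType(
    [('x', computation_types.TensorType(tf.int32, [3])),
     ('y', computation_types.TensorType(tf.float32))],
    collections.OrderedDict)
dtypes, shapes = type_to_tf_dtypes_and_shapes(spec)
# dtypes -> OrderedDict([('x', tf.int32), ('y', tf.float32)])
# shapes -> OrderedDict([('x', TensorShape([3])), ('y', TensorShape([]))])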
Example #24
def _is_two_tuple(t: computation_types.Type) -> bool:
    return t.is_struct() and len(t) == 2
Example #25
 def _concretize_abstract_types(
     abstract_type_spec: computation_types.Type,
     concrete_type_spec: computation_types.Type) -> computation_types.Type:
   """Recursive helper function to construct concrete type spec."""
   if abstract_type_spec.is_abstract():
     bound_type = bound_abstract_types.get(str(abstract_type_spec.label))
     if bound_type:
       return bound_type
     else:
       bound_abstract_types[str(abstract_type_spec.label)] = concrete_type_spec
       return concrete_type_spec
   elif abstract_type_spec.is_tensor():
     return abstract_type_spec
   elif abstract_type_spec.is_struct():
     if not concrete_type_spec.is_struct():
       raise TypeError(type_error_string)
     abstract_elements = structure.to_elements(abstract_type_spec)
     concrete_elements = structure.to_elements(concrete_type_spec)
     if len(abstract_elements) != len(concrete_elements):
       raise TypeError(type_error_string)
     concretized_tuple_elements = []
     for k in range(len(abstract_elements)):
       if abstract_elements[k][0] != concrete_elements[k][0]:
         raise TypeError(type_error_string)
       concretized_tuple_elements.append(
           (abstract_elements[k][0],
            _concretize_abstract_types(abstract_elements[k][1],
                                       concrete_elements[k][1])))
     return computation_types.StructType(concretized_tuple_elements)
   elif abstract_type_spec.is_sequence():
     if not concrete_type_spec.is_sequence():
       raise TypeError(type_error_string)
     return computation_types.SequenceType(
         _concretize_abstract_types(abstract_type_spec.element,
                                    concrete_type_spec.element))
   elif abstract_type_spec.is_function():
     if not concrete_type_spec.is_function():
       raise TypeError(type_error_string)
     if abstract_type_spec.parameter is None:
       if concrete_type_spec.parameter is not None:
         raise TypeError(type_error_string)
       concretized_param = None
     else:
       concretized_param = _concretize_abstract_types(
           abstract_type_spec.parameter, concrete_type_spec.parameter)
     concretized_result = _concretize_abstract_types(abstract_type_spec.result,
                                                     concrete_type_spec.result)
     return computation_types.FunctionType(concretized_param,
                                           concretized_result)
   elif abstract_type_spec.is_placement():
     if not concrete_type_spec.is_placement():
       raise TypeError(type_error_string)
     return abstract_type_spec
   elif abstract_type_spec.is_federated():
     if not concrete_type_spec.is_federated():
       raise TypeError(type_error_string)
     new_member = _concretize_abstract_types(abstract_type_spec.member,
                                             concrete_type_spec.member)
     return computation_types.FederatedType(new_member,
                                            abstract_type_spec.placement,
                                            abstract_type_spec.all_equal)
   else:
     raise TypeError(
         'Unexpected abstract typespec {}.'.format(abstract_type_spec))
Example #26
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`.

  Note: If `operand_type` is a `computation_types.StructType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode. For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or `operator`
      is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.StructType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)

    if operand_type is not None:
      if operand_type.is_tensor():
        result_value = operator(operand_1_value, operand_2_value)
      elif operand_type.is_struct():
        result_value = structure.map_structure(operator, operand_1_value,
                                               operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.StructType((operand_type, operand_type)), result_type)
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
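A hedged usage sketch, assuming `ProtoAndType` unpacks as a `(proto, type)` pair: a pointwise add over a two-element structure, matching the `(<T,T> -> U)` signature described above.

operand_type = computation_types.StructType([tf.int32, tf.int32])
_, op_type = create_binary_operator(tf.math.add, operand_type)
# op_type is `(<<int32,int32>,<int32,int32>> -> <int32,int32>)`; to add two
# structures, callers pack them as the two elements of the parameter tuple.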
Example #27
def _federated_same_placement(x: computation_types.Type,
                              y: computation_types.Type) -> bool:
    return x.is_federated() and y.is_federated() and x.placement == y.placement
Example #28
  def _check_helper(generic_type_member: computation_types.Type,
                    concrete_type_member: computation_types.Type,
                    defining: bool):
    """Recursive helper function."""

    def _raise_structural(mismatch):
      raise MismatchedStructureError(concrete_type, generic_type,
                                     concrete_type_member, generic_type_member,
                                     mismatch)

    def _both_are(predicate):
      if predicate(generic_type_member):
        if predicate(concrete_type_member):
          return True
        else:
          _raise_structural('kind')
      else:
        return False

    if generic_type_member.is_abstract():
      label = str(generic_type_member.label)
      if not defining:
        non_defining_usages[label].append(concrete_type_member)
      else:
        bound_type = type_bindings.get(label)
        if bound_type is not None:
          if not concrete_type_member.is_equivalent_to(bound_type):
            raise MismatchedConcreteTypesError(concrete_type, generic_type,
                                               label, bound_type,
                                               concrete_type_member)
        else:
          type_bindings[label] = concrete_type_member
    elif _both_are(lambda t: t.is_tensor()):
      if generic_type_member != concrete_type_member:
        _raise_structural('tensor types')
    elif _both_are(lambda t: t.is_placement()):
      if generic_type_member != concrete_type_member:
        _raise_structural('placements')
    elif _both_are(lambda t: t.is_struct()):
      generic_elements = structure.to_elements(generic_type_member)
      concrete_elements = structure.to_elements(concrete_type_member)
      if len(generic_elements) != len(concrete_elements):
        _raise_structural('length')
      for k in range(len(generic_elements)):
        if generic_elements[k][0] != concrete_elements[k][0]:
          _raise_structural('element names')
        _check_helper(generic_elements[k][1], concrete_elements[k][1], defining)
    elif _both_are(lambda t: t.is_sequence()):
      _check_helper(generic_type_member.element, concrete_type_member.element,
                    defining)
    elif _both_are(lambda t: t.is_function()):
      if generic_type_member.parameter is None:
        if concrete_type_member.parameter is not None:
          _raise_structural('parameter')
      else:
        _check_helper(generic_type_member.parameter,
                      concrete_type_member.parameter, not defining)
      _check_helper(generic_type_member.result, concrete_type_member.result,
                    defining)
    elif _both_are(lambda t: t.is_federated()):
      if generic_type_member.placement != concrete_type_member.placement:
        _raise_structural('placement')
      if generic_type_member.all_equal != concrete_type_member.all_equal:
        _raise_structural('all equal')
      _check_helper(generic_type_member.member, concrete_type_member.member,
                    defining)
    else:
      raise TypeError(f'Unexpected type kind {generic_type}.')
Example #29
def _deserialize_dataset_from_graph_def(serialized_graph_def: bytes,
                                        element_type: computation_types.Type):
    """Deserializes a serialized `tf.compat.v1.GraphDef` to a `tf.data.Dataset`.

  Args:
    serialized_graph_def: `bytes` object produced by
      `tensorflow_serialization.serialize_dataset`
    element_type: a `tff.Type` object representing the type structure of the
      elements yielded from the dataset.

  Returns:
    A `tf.data.Dataset` instance.
  """
    py_typecheck.check_type(element_type, computation_types.Type)
    type_analysis.check_tensorflow_compatible_type(element_type)

    def transform_to_tff_known_type(
        type_spec: computation_types.Type
    ) -> Tuple[computation_types.Type, bool]:
        """Transforms `StructType` to `StructWithPythonType`."""
        if type_spec.is_struct() and not type_spec.is_struct_with_python():
            field_is_named = tuple(
                name is not None
                for name, _ in structure.iter_elements(type_spec))
            has_names = any(field_is_named)
            is_all_named = all(field_is_named)
            if is_all_named:
                return computation_types.StructWithPythonType(
                    elements=structure.iter_elements(type_spec),
                    container_type=collections.OrderedDict), True
            elif not has_names:
                return computation_types.StructWithPythonType(
                    elements=structure.iter_elements(type_spec),
                    container_type=tuple), True
            else:
                raise TypeError(
                    'Cannot represent TFF type in TF because it contains '
                    f'partially named structures. Type: {type_spec}')
        return type_spec, False

    if element_type.is_struct():
        # TF doesn't support `structure.Struct` types, so we must transform the
        # `StructType` into a `StructWithPythonType` for use as the
        # `tf.data.Dataset.element_spec` later.
        tf_compatible_type, _ = type_transformations.transform_type_postorder(
            element_type, transform_to_tff_known_type)
    else:
        # We've already checked that this type contains only structs and tensors,
        # so at this point it must be a `TensorType` and can be used as-is.
        tf_compatible_type = element_type

    def type_to_tensorspec(t: computation_types.TensorType) -> tf.TensorSpec:
        return tf.TensorSpec(shape=t.shape, dtype=t.dtype)

    element_spec = type_conversions.structure_from_tensor_type_tree(
        type_to_tensorspec, tf_compatible_type)
    ds = tf.data.experimental.from_variant(
        tf.raw_ops.DatasetFromGraph(graph_def=serialized_graph_def),
        structure=element_spec)
    # If a serialized dataset had elements that were nested structures of tensors
    # (e.g. `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
    # `tuple`, or `namedtuple` (losing the `collections.OrderedDict` container in
    # the conversion).
    #
    # Since the dataset will only be used inside TFF, we wrap the dictionary
    # coming from TF in an `OrderedDict` when necessary (a type that both TF and
    # TFF understand), using the field order from the TFF type recorded during
    # serialization.
    return tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, tf_compatible_type)
Example #30
def _preorder_types(type_signature: computation_types.Type):
  yield type_signature
  for child in type_signature.children():
    yield from _preorder_types(child)
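A tiny sketch: the generator yields the type itself first, then each child subtree in preorder. The structure below is hypothetical.

t = computation_types.StructType([('a', computation_types.TensorType(tf.int32))])
nodes = list(_preorder_types(t))
assert len(nodes) == 2  # the struct itself, then the int32 leaf
assert nodes[0].is_struct() and nodes[1].is_tensor()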