Exemplo n.º 1
0
def _create_structure_of_coro_references(
    coro: Coroutine[Any, Any,
                    Any], type_signature: computation_types.Type) -> Any:
  """Wraps the eventual result of `coro` in `CoroValueReference`s.

  The shape of the returned object follows `type_signature`:

  * struct types yield a `structure.Struct` whose elements are built
    recursively, each one from a coroutine that picks its element out of the
    awaited result of `coro`;
  * server-placed federated types are handled as their member type;
  * sequence and tensor types yield a single `CoroValueReference`.

  Args:
    coro: The coroutine whose result is being referenced.
    type_signature: The `computation_types.Type` describing that result.

  Returns:
    A `CoroValueReference` or a `structure.Struct` of them.

  Raises:
    NotImplementedError: If `type_signature` is any other kind of type
      (including federated types that are not placed at the server).
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  if type_signature.is_struct():

    async def _as_struct(inner: Coroutine[Any, Any, Any]) -> structure.Struct:
      return structure.from_container(await inner)

    # Each element coroutine awaits the same underlying result, so the
    # struct-producing coroutine is wrapped once and that wrapper is handed to
    # every element (presumably `SharedAwaitable` permits repeated awaits).
    shared = async_utils.SharedAwaitable(_as_struct(coro))

    async def _select(source: Awaitable[structure.Struct],
                      position: int) -> Any:
      resolved = await source
      return resolved[position]

    return structure.Struct([
        (name,
         _create_structure_of_coro_references(
             _select(shared, position), element_type))
        for position, (name, element_type) in enumerate(
            structure.iter_elements(type_signature))
    ])

  if (type_signature.is_federated() and
      type_signature.placement == placements.SERVER):
    return _create_structure_of_coro_references(coro, type_signature.member)

  if type_signature.is_sequence() or type_signature.is_tensor():
    return CoroValueReference(coro, type_signature)

  raise NotImplementedError(f'Unexpected type found: {type_signature}.')
Exemplo n.º 2
0
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Returns whether `bitwidth_type` can describe bitwidths for `value_type`.

  A tensor `bitwidth_type` is valid when it is an integer tensor with exactly
  one element, whatever `value_type` is. A struct `bitwidth_type` is valid only
  against a struct `value_type` with the same number of elements, the same
  element names in the same order, and pairwise-valid element types.

  Args:
    bitwidth_type: The candidate bitwidth `computation_types.Type`.
    value_type: The `computation_types.Type` of the values.

  Returns:
    `True` if the pairing is valid, `False` otherwise.
  """
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if bitwidth_type.is_tensor():
    # One integer bitwidth may be shared by every value, so `value_type` is not
    # consulted here.
    is_integer = bitwidth_type.dtype.is_integer
    return is_integer and (bitwidth_type.shape.num_elements() == 1)

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_elements = list(structure.iter_elements(bitwidth_type))
  value_elements = list(structure.iter_elements(value_type))
  if len(bitwidth_elements) != len(value_elements):
    return False

  return all(
      bitwidth_name == value_name and
      is_valid_bitwidth_type_for_value_type(bitwidth_element, value_element)
      for (bitwidth_name, bitwidth_element), (value_name, value_element) in zip(
          bitwidth_elements, value_elements))
Exemplo n.º 3
0
def is_single_integer_or_matches_structure(
    type_sig: computation_types.Type,
    shape_type: computation_types.Type) -> bool:
  """If `type_sig` is an integer or integer structure matching `shape_type`.

  Args:
    type_sig: The `computation_types.Type` to validate.
    shape_type: The `computation_types.Type` whose structure `type_sig` must
      mirror when `type_sig` is itself a structure.

  Returns:
    `True` if `type_sig` is a single-element integer tensor, or if both types
    are structures with the same length, the same element names in the same
    order, and element types that recursively satisfy this predicate. `False`
    otherwise.
  """

  py_typecheck.check_type(type_sig, computation_types.Type)
  py_typecheck.check_type(shape_type, computation_types.Type)

  if type_sig.is_tensor():
    # A single integer is accepted whether `shape_type` is a tensor or a
    # structure; `shape_type` is not consulted in this branch.
    return type_sig.dtype.is_integer and (type_sig.shape.num_elements() == 1)
  elif shape_type.is_struct() and type_sig.is_struct():
    type_sig_elements = list(structure.iter_elements(type_sig))
    shape_elements = list(structure.iter_elements(shape_type))
    # Compare the materialized element lists so the check does not rely on the
    # type object implementing `__len__`.
    if len(type_sig_elements) != len(shape_elements):
      return False
    # Distinct loop names keep the `type_sig` parameter from being rebound.
    for (inner_name, inner_type_sig), (inner_shape_name,
                                       inner_shape_type) in zip(
                                           type_sig_elements, shape_elements):
      if inner_name != inner_shape_name:
        return False
      if not is_single_integer_or_matches_structure(inner_type_sig,
                                                    inner_shape_type):
        return False
    return True
  else:
    return False
Exemplo n.º 4
0
async def _materialize_structure_of_value_references(
    value: Any, type_signature: computation_types.Type) -> Any:
  """Resolves every materializable reference inside `value`.

  Args:
    value: A value, or structure of values, possibly containing
      `value_reference.MaterializableValueReference`s.
    type_signature: The `computation_types.Type` describing `value`.

  Returns:
    For struct types, a `structure.Struct` whose elements were materialized
    recursively and concurrently. For server-placed federated types, the result
    for the member type. For sequence and tensor types, the awaited
    `get_value()` of a reference, or `value` itself when it is not a reference.
    For every other type, `value` unchanged.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  if type_signature.is_struct():
    container = structure.from_container(value)
    named_types = list(structure.iter_elements(type_signature))
    pending = [
        _materialize_structure_of_value_references(item, item_type)
        for item, (_, item_type) in zip(container, named_types)
    ]
    materialized = await asyncio.gather(*pending)
    # Re-attach the names declared by the type to the materialized elements.
    return structure.Struct([
        (name, item) for item, (name, _) in zip(materialized, named_types)
    ])

  if (type_signature.is_federated() and
      type_signature.placement == placements.SERVER):
    return await _materialize_structure_of_value_references(
        value, type_signature.member)

  if type_signature.is_sequence() or type_signature.is_tensor():
    if isinstance(value, value_reference.MaterializableValueReference):
      return await value.get_value()

  return value
Exemplo n.º 5
0
def _repackage_partitioned_values(after_merge_results,
                                  result_type_spec: computation_types.Type):
    """Combines per-subround results into one value of `result_type_spec`.

    Described by the original author as the inverse of
    `_split_value_into_subrounds`.

    Args:
      after_merge_results: A `list` with one result per subround.
      result_type_spec: The `computation_types.Type` of the combined result.

    Returns:
      For struct types, a `structure.Struct` whose elements are repackaged
      recursively. For clients-placed federated types that are not `all_equal`,
      the concatenation (via `+`) of every subround's list or tuple. Otherwise
      the first subround's result.
    """
    py_typecheck.check_type(after_merge_results, list)

    if result_type_spec.is_struct():
        per_subround_structs = [
            structure.from_container(subround)
            for subround in after_merge_results
        ]
        repackaged = [
            (name,
             _repackage_partitioned_values(
                 [struct[position] for struct in per_subround_structs],
                 elem_type))
            for position, (name, elem_type) in enumerate(
                structure.iter_elements(result_type_spec))
        ]
        return structure.Struct(repackaged)

    clients_placed = (result_type_spec.is_federated()
                      and result_type_spec.placement.is_clients())
    if clients_placed and not result_type_spec.all_equal:
        for subround in after_merge_results:
            py_typecheck.check_type(subround, (list, tuple))
        # Merges all clients-placed values back together.
        return functools.reduce(lambda merged, nxt: merged + nxt,
                                after_merge_results)

    # Every subround carries the same value here, so the first one is used.
    return after_merge_results[0]
def create_identity(
        type_signature: computation_types.Type) -> ComputationProtoAndType:
    """Returns a tensorflow computation representing an identity function.

    The returned computation has the type signature `(T -> T)`, where `T` is
    `type_signature`. NOTE: if `T` contains `computation_types.StructType`s
    without an associated container type, they will be given the container
    type `tuple` by this function.

    Args:
      type_signature: A `computation_types.Type` to use as the parameter type
        and result type of the identity function.

    Returns:
      The result of `create_computation_for_py_fn` applied to an identity
      function and `type_signature`.

    Raises:
      TypeError: If `type_signature` is `None` or contains any types which
        cannot appear in TensorFlow bindings.
      NotImplementedError: If `type_signature` is not a tensor, sequence or
        struct type.
    """
    type_analysis.check_tensorflow_compatible_type(type_signature)
    if type_signature is None:
        raise TypeError('TensorFlow identity cannot be created for NoneType.')

    # TF relies on feeds not-identical to fetches in certain circumstances,
    # hence the explicit `tf.identity` rather than returning the input as is.
    if type_signature.is_tensor() or type_signature.is_sequence():
        return create_computation_for_py_fn(tf.identity, type_signature)

    if type_signature.is_struct():
        map_identity = functools.partial(structure.map_structure, tf.identity)
        return create_computation_for_py_fn(map_identity, type_signature)

    raise NotImplementedError(
        f'TensorFlow identity cannot be created for type {type_signature}')
Exemplo n.º 7
0
 def is_allowed_client_data_type(
         type_spec: computation_types.Type) -> bool:
     """Returns whether `type_spec` is an acceptable client data type.

     A sequence type is allowed when its element type is TensorFlow
     compatible; a struct type is allowed when every child type is allowed;
     any other type is rejected.
     """
     if type_spec.is_sequence():
         return type_analysis.is_tensorflow_compatible_type(
             type_spec.element)
     if not type_spec.is_struct():
         return False
     for child_type in type_spec.children():
         if not is_allowed_client_data_type(child_type):
             return False
     return True
 def ingest_value(
     self, value: Any, type_signature: computation_types.Type
 ) -> executor_value_base.ExecutorValue:
     """Wraps `value` in a `FederatedResolvingStrategyValue`.

     When `type_signature` is federated, or is a function whose result is
     federated, the corresponding placement is first passed to
     `_check_strategy_compatible_with_placement`.

     Args:
       value: The value to ingest.
       type_signature: The `computation_types.Type` of `value`, or `None`.

     Returns:
       A `FederatedResolvingStrategyValue` built from `value` and
       `type_signature`.
     """
     federated_type = None
     if type_signature is not None:
         if type_signature.is_federated():
             federated_type = type_signature
         elif (type_signature.is_function()
               and type_signature.result.is_federated()):
             federated_type = type_signature.result

     if federated_type is not None:
         self._check_strategy_compatible_with_placement(
             federated_type.placement)
     return FederatedResolvingStrategyValue(value, type_signature)
Exemplo n.º 9
0
def _to_sequence_internal_rep(
    *, value: Any, type_spec: computation_types.Type) -> tf.data.Dataset:
  """Ingests `value`, converting it to a `tf.data.Dataset`.

  Args:
    value: A `list` of elements, an instance of one of
      `type_conversions.TF_DATASET_REPRESENTATION_TYPES`, or anything else
      accepted by `tf.data.Dataset.from_generator` (presumably a generator
      callable).
    type_spec: The expected sequence type of the dataset.

  Returns:
    The dataset representing `value`.
  """
  candidate = value
  if isinstance(candidate, list):
    candidate = tensorflow_utils.make_data_set_from_elements(
        None, candidate, type_spec.element)

  if not isinstance(candidate,
                    type_conversions.TF_DATASET_REPRESENTATION_TYPES):
    # Not a dataset yet: build one from the generator, using the element type
    # of `type_spec` as the output signature.
    py_typecheck.check_type(type_spec, computation_types.SequenceType)
    return tf.data.Dataset.from_generator(
        candidate,
        output_signature=type_conversions.type_to_tf_tensor_specs(
            type_spec.element))

  # Already a dataset: verify its element spec is assignable to `type_spec`.
  inferred_type = computation_types.SequenceType(
      computation_types.to_type(candidate.element_spec))
  type_spec.check_assignable_from(inferred_type)
  return candidate
Exemplo n.º 10
0
def _ensure_deserialized_types_compatible(
        previous_type: Optional[computation_types.Type],
        next_type: computation_types.Type) -> computation_types.Type:
    """Ensures one of `previous_type` or `next_type` is assignable to the other.

  Returns the type which is assignable from the other.

  Args:
    previous_type: Instance of `computation_types.Type` or `None`.
    next_type: Instance of `computation_types.Type`.

  Returns:
    The supertype of `previous_type` and `next_type`.

  Raises:
    TypeError if neither type is assignable from the other.
  """
    if previous_type is None:
        return next_type
    else:
        if next_type.is_assignable_from(previous_type):
            return next_type
        elif previous_type.is_assignable_from(next_type):
            return previous_type
        raise TypeError(
            'Type mismatch checking member assignability under a '
            'federated value. Deserialized type {} is incompatible '
            'with previously deserialized {}.'.format(next_type,
                                                      previous_type))
Exemplo n.º 11
0
def _stamp_value_into_graph(value: Any, type_signature: computation_types.Type,
                            graph: tf.Graph) -> Any:
    """Stamps `value` in `graph` as an object of type `type_signature`.

  Args:
    value: A value to stamp.
    type_signature: The type of the value to stamp.
    graph: The graph to stamp in.

  Returns:
    A Python object made of tensors stamped into `graph`, `tf.data.Dataset`s,
    or `structure.Struct`s that structurally corresponds to the value passed at
    input.
  """
    if value is None:
        return None
    if type_signature.is_tensor():
        if isinstance(value, np.ndarray) or tf.is_tensor(value):
            value_type = computation_types.TensorType(
                tf.dtypes.as_dtype(value.dtype), tf.TensorShape(value.shape))
            type_signature.check_assignable_from(value_type)
            with graph.as_default():
                return tf.constant(value)
        else:
            with graph.as_default():
                return tf.constant(value,
                                   dtype=type_signature.dtype,
                                   shape=type_signature.shape)
    elif type_signature.is_struct():
        if isinstance(value, (list, dict)):
            value = structure.from_container(value)
        stamped_elements = []
        named_type_signatures = structure.to_elements(type_signature)
        for (name, type_signature), element in zip(named_type_signatures,
                                                   value):
            stamped_element = _stamp_value_into_graph(element, type_signature,
                                                      graph)
            stamped_elements.append((name, stamped_element))
        return structure.Struct(stamped_elements)
    elif type_signature.is_sequence():
        return tensorflow_utils.make_data_set_from_elements(
            graph, value, type_signature.element)
    else:
        raise NotImplementedError(
            'Unable to stamp a value of type {} in graph.'.format(
                type_signature))
Exemplo n.º 12
0
def type_to_tf_structure(type_spec: computation_types.Type):
  """Returns nested `tf.data.experimental.Structure` for a given TFF type.

  Args:
    type_spec: A `computation_types.Type` composed only of struct and tensor
      types. Within any one struct, either all elements are named or none are.

  Returns:
    A `tf.TensorSpec` for a tensor type. For a struct type, the converted
    elements packed into the struct's Python container if it has one, else
    into a `collections.OrderedDict` (named elements) or a `tuple` (unnamed
    elements).

  Raises:
    ValueError: If `type_spec` contains something other than structs and
      tensors, contains an empty struct, or contains a struct whose elements
      are only partially named.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return tf.TensorSpec(type_spec.shape, type_spec.dtype)
  if not type_spec.is_struct():
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))

  elements = structure.to_elements(type_spec)
  if not elements:
    raise ValueError('Empty tuples are unsupported.')

  converted = [(name, type_to_tf_structure(elem)) for name, elem in elements]
  is_named = converted[0][0] is not None
  if any((name is not None) != is_named for name, _ in converted):
    raise ValueError('Tuple elements inconsistently named.')

  container_type = type_spec.python_container
  if container_type is None:
    if is_named:
      return collections.OrderedDict(converted)
    return tuple(spec for _, spec in converted)

  if (py_typecheck.is_named_tuple(container_type) or
      py_typecheck.is_attrs(container_type)):
    return container_type(**dict(converted))
  if is_named:
    return container_type(converted)
  # The consistency check above guarantees every name is `None` here, so only
  # the converted specs are handed to the container.
  return container_type(spec for _, spec in converted)
Exemplo n.º 13
0
 def _pack_into_type(to_pack: tf.Tensor, type_spec: computation_types.Type):
   """Packs tensor `to_pack` into the nested structure `type_spec`.

   For a struct type every element is packed recursively with the same
   `to_pack`. For a tensor type, `to_pack` is returned unchanged when its type
   is already assignable to `type_spec`; otherwise it is broadcast to the
   shape of `type_spec` and cast to its dtype. For any other type nothing is
   packed and `None` is returned.

   Raises:
     TypeError: If broadcasting is required but the shape of `type_spec` is
       not fully defined.
   """
   if type_spec.is_struct():
     packed = []
     for elem_name, elem_type in structure.iter_elements(type_spec):
       packed.append((elem_name, _pack_into_type(to_pack, elem_type)))
     return structure.Struct(packed)

   if not type_spec.is_tensor():
     return None

   value_tensor_type = type_conversions.type_from_tensors(to_pack)
   if type_spec.is_assignable_from(value_tensor_type):
     return to_pack
   if not type_spec.shape.is_fully_defined():
     raise TypeError('Cannot generate TensorFlow creating binary operator '
                     'with first type not assignable from second, and '
                     'first type without fully defined shapes. First '
                     f'type contains an element of type: {type_spec}.\n'
                     f'Packing value {to_pack} into this type is '
                     'undefined.')
   broadcasted = tf.broadcast_to(to_pack, type_spec.shape)
   return tf.cast(broadcasted, type_spec.dtype)
Exemplo n.º 14
0
def _check_type_is_fn(
    target: computation_types.Type,
    name: str,
    err_fn: Callable[[str],
                     Exception] = compiler.MapReduceFormCompilationError,
):
    """Raises if `target` is not a function type.

    Args:
      target: The `computation_types.Type` to inspect.
      name: Human-readable name of the checked object, used in the message.
      err_fn: Factory that builds the exception to raise from a message.

    Raises:
      The exception built by `err_fn` when `target.is_function()` is false.
    """
    if target.is_function():
        return
    raise err_fn(f'Expected {name} to be a function, but {name} had type '
                 f'{target}.')
Exemplo n.º 15
0
def _partition_value(
        val: _PartitioningValue,
        type_signature: computation_types.Type) -> _PartitioningValue:
    """Partitions value as specified in _split_value_into_subrounds.

    Args:
      val: The `_PartitioningValue` holding the payload still to be split
        together with the partitioning counters (`num_remaining_clients`,
        `num_remaining_partitions`, `last_client_index`).
      type_signature: The `computation_types.Type` describing `val.payload`.

    Returns:
      A `_PartitioningValue` whose payload is the slice for the current
      subround. For struct types the payload is a `structure.Struct` of
      recursively partitioned elements. For clients-placed federated types
      that are not `all_equal`, the payload is the next
      `ceil(num_remaining_clients / num_remaining_partitions)` entries and the
      counters are advanced. In every other case `val` is returned unchanged.
    """
    if type_signature.is_struct():
        struct_val = structure.from_container(val.payload)
        result_container = []
        # Seed with `val` so an empty struct returns the incoming counters
        # instead of failing on an unbound `partition_result` after the loop.
        partition_result = val
        for (_, val_elem), (name, type_elem) in zip(
                structure.iter_elements(struct_val),
                structure.iter_elements(type_signature)):
            # Every element starts from the same counters as its parent.
            partitioning_val_elem = _PartitioningValue(
                val_elem, val.num_remaining_clients,
                val.num_remaining_partitions, val.last_client_index)
            partition_result = _partition_value(partitioning_val_elem,
                                                type_elem)
            result_container.append((name, partition_result.payload))
        # NOTE(review): the counters returned are those of the LAST element
        # only; confirm with the caller that this is intended when a
        # non-clients element follows a clients-placed one.
        return _PartitioningValue(structure.Struct(result_container),
                                  partition_result.num_remaining_clients,
                                  partition_result.num_remaining_partitions,
                                  partition_result.last_client_index)
    elif (type_signature.is_federated()
          and type_signature.placement.is_clients()):

        if type_signature.all_equal:
            # In this case we simply replicate the argument for every subround.
            return val

        py_typecheck.check_type(val.payload, Sequence)
        # Spread the remaining clients evenly over the remaining partitions,
        # rounding up so earlier subrounds absorb any remainder.
        num_clients_for_subround = math.ceil(val.num_remaining_clients /
                                             val.num_remaining_partitions)
        first_index = val.last_client_index
        last_client_index = first_index + num_clients_for_subround
        return _PartitioningValue(
            payload=val.payload[first_index:last_client_index],
            num_remaining_clients=(val.num_remaining_clients -
                                   num_clients_for_subround),
            num_remaining_partitions=val.num_remaining_partitions - 1,
            last_client_index=last_client_index)
    else:
        return val
Exemplo n.º 16
0
 def transform_to_tff_known_type(
     type_spec: computation_types.Type
 ) -> Tuple[computation_types.Type, bool]:
     """Transforms `StructType` to `StructWithPythonType`.

     Returns:
       A `(type, changed)` pair. Struct types without a Python container
       become a `StructWithPythonType` backed by `collections.OrderedDict`
       when every field is named, or by `tuple` when no field is named, with
       `changed` set to `True`. Every other type is returned as is with
       `changed` set to `False`.

     Raises:
       TypeError: If only some of the struct's fields are named.
     """
     if not type_spec.is_struct() or type_spec.is_struct_with_python():
         return type_spec, False

     names_present = [
         name is not None for name, _ in structure.iter_elements(type_spec)
     ]
     if all(names_present):
         container_type = collections.OrderedDict
     elif not any(names_present):
         container_type = tuple
     else:
         raise TypeError(
             'Cannot represent TFF type in TF because it contains '
             f'partially named structures. Type: {type_spec}')

     return computation_types.StructWithPythonType(
         elements=structure.iter_elements(type_spec),
         container_type=container_type), True
Exemplo n.º 17
0
def _unique_dtypes_in_structure(
        type_spec: computation_types.Type) -> Set[tf.DType]:
    """Collects the distinct dtypes that appear in `type_spec`.

    Args:
      type_spec: A `computation_types.Type`.

    Returns:
      A `set` with the dtype of a tensor type, the dtypes found in a struct
      type, or the dtypes of the member of a federated type. The set is empty
      for any other kind of type.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        py_typecheck.check_type(type_spec.dtype, tf.DType)
        return {type_spec.dtype}

    if type_spec.is_struct():
        dtype_tree = type_conversions.structure_from_tensor_type_tree(
            lambda tensor_type: tensor_type.dtype, type_spec)
        return set(tf.nest.flatten(dtype_tree))

    if type_spec.is_federated():
        return _unique_dtypes_in_structure(type_spec.member)

    return set()
Exemplo n.º 18
0
def is_binary_op_with_upcast_compatible_pair(
    possibly_nested_type: Optional[computation_types.Type],
    type_to_upcast: computation_types.Type) -> bool:
  """Checks unambiguity in applying `type_to_upcast` to `possibly_nested_type`.

  The pair is compatible when both types pass `is_generic_op_compatible_type`
  and either

  * both are `None`,
  * the two types are equivalent, or
  * `type_to_upcast` is a scalar tensor type and every tensor inside
    `possibly_nested_type` has the same dtype as `type_to_upcast`.

  The relationship is deliberately asymmetric: it makes sense to divide a
  nested structure of tensors by a scalar, but not the other way around.

  Args:
    possibly_nested_type: A `computation_types.Type`, or `None`.
    type_to_upcast: A `computation_types.Type`, or `None`.

  Returns:
    Boolean indicating whether `type_to_upcast` can be upcast to
    `possibly_nested_type` in the manner described above.
  """
  for candidate in (possibly_nested_type, type_to_upcast):
    if candidate is not None:
      py_typecheck.check_type(candidate, computation_types.Type)

  if not is_generic_op_compatible_type(possibly_nested_type):
    return False
  if not is_generic_op_compatible_type(type_to_upcast):
    return False
  if possibly_nested_type is None:
    return type_to_upcast is None
  if possibly_nested_type.is_equivalent_to(type_to_upcast):
    return True

  is_scalar_tensor = (
      type_to_upcast.is_tensor() and
      type_to_upcast.shape == tf.TensorShape(()))
  if not is_scalar_tensor:
    return False

  required_dtype = type_to_upcast.dtype
  mismatch_found = False

  def _flag_mismatched_tensor(type_spec):
    # Used only for its side effect; the type itself is never transformed.
    nonlocal mismatch_found
    if type_spec.is_tensor() and type_spec.dtype != required_dtype:
      mismatch_found = True
    return type_spec, False

  type_transformations.transform_type_postorder(possibly_nested_type,
                                                _flag_mismatched_tensor)

  return not mismatch_found
Exemplo n.º 19
0
def is_average_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` can be averaged.

  A type is average-compatible when it is a floating-point or complex tensor
  type, a struct whose elements are all average-compatible, or a federated
  type whose member is average-compatible.

  Args:
    type_spec: a `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    return type_spec.dtype.is_floating or type_spec.dtype.is_complex

  if type_spec.is_struct():
    for _, element_type in structure.iter_elements(type_spec):
      if not is_average_compatible(element_type):
        return False
    return True

  if type_spec.is_federated():
    return is_average_compatible(type_spec.member)

  return False
Exemplo n.º 20
0
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a structure of integers.

  Integer tensor types qualify, as do structs whose elements all qualify and
  federated types whose member qualifies. An empty struct returns `True`
  because it contains no non-integer types.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of integers, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    py_typecheck.check_type(type_spec.dtype, tf.DType)
    return type_spec.dtype.is_integer

  if type_spec.is_struct():
    for _, element_type in structure.iter_elements(type_spec):
      if not is_structure_of_integers(element_type):
        return False
    return True

  if type_spec.is_federated():
    return is_structure_of_integers(type_spec.member)

  return False
Exemplo n.º 21
0
def _to_tensor_internal_rep(*, value: Any,
                            type_spec: computation_types.Type) -> tf.Tensor:
  """Normalizes a tensor-like `value` to a `tf.Tensor` matching `type_spec`.

  Args:
    value: A tensor, a variable-like object exposing `read_value`, or anything
      accepted by `tf.convert_to_tensor`.
    type_spec: The expected tensor type.

  Returns:
    The normalized tensor.

  Raises:
    TypeError: If the dtype and shape of the normalized tensor are not
      assignable to `type_spec`.
  """
  if tf.is_tensor(value):
    if hasattr(value, 'read_value'):
      # a tf.Variable-like result, get a proper tensor.
      value = value.read_value()
  else:
    value = tf.convert_to_tensor(value, dtype=type_spec.dtype)

  value_type = computation_types.TensorType(value.dtype.base_dtype,
                                            value.shape)
  if type_spec.is_assignable_from(value_type):
    return value
  raise TypeError(
      'The apparent type {} of a tensor {} does not match the expected '
      'type {}.'.format(value_type, value, type_spec))
Exemplo n.º 22
0
def check_type(value: Any, type_spec: computation_types.Type):
  """Checks whether `value` is of TFF type `type_spec`.

  Args:
    value: The object to check.
    type_spec: A `computation_types.Type`, the type that `value` is checked
      against.

  Raises:
    TypeError: If the inferred type of `value` is not assignable to
      `type_spec`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  inferred_type = type_conversions.infer_type(value)
  if type_spec.is_assignable_from(inferred_type):
    return
  raise TypeError(
      'Expected TFF type {}, which is not assignable from {}.'.format(
          type_spec, inferred_type))
Exemplo n.º 23
0
def visit_preorder(type_signature: computation_types.Type,
                   fn: Callable[[computation_types.Type, T], T], context: T):
    """Applies `fn` to `type_signature` and then to each of its descendants.

    The tree is walked parent-first. The context returned by `fn` for a node
    is the one handed to every direct child of that node, so information
    flows down the tree only.

    Args:
      type_signature: A `computation_types.Type`.
      fn: A function called with each constituent type and the context of its
        parent. Must return the context to pass on to that type's children.
      context: Initial state of information to be passed down the tree.
    """
    child_context = fn(type_signature, context)
    for child in type_signature.children():
        visit_preorder(child, fn, child_context)
Exemplo n.º 24
0
def check_type(value: Any, type_spec: computation_types.Type):
  """Checks whether `value` is of TFF type `type_spec`.

  Args:
    value: The object to check.
    type_spec: A `computation_types.Type`, the type that `value` is checked
      against.

  Raises:
    TypeError: If the inferred type of `value` is not assignable to `type_spec`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  inferred_type = type_conversions.infer_type(value)
  if type_spec.is_assignable_from(inferred_type):
    return
  message = computation_types.type_mismatch_error_message(
      inferred_type,
      type_spec,
      computation_types.TypeRelation.ASSIGNABLE,
      second_is_expected=True)
  raise TypeError(message)
Exemplo n.º 25
0
def reconcile_value_type_with_type_spec(
        value_type: computation_types.Type,
        type_spec: Optional[computation_types.Type]) -> computation_types.Type:
    """Reconciles a pair of types.

    Args:
      value_type: An instance of `tff.Type`.
      type_spec: An instance of `tff.Type`, or `None`.

    Returns:
      Either `value_type` if `type_spec` is `None`, or `type_spec` if
      `type_spec` is not `None` and equivalent with `value_type`.

    Raises:
      TypeError: If arguments are of incompatible types.
    """
    py_typecheck.check_type(value_type, computation_types.Type)
    if type_spec is None:
        return value_type
    # Validate `type_spec` itself; `value_type` was already checked above.
    py_typecheck.check_type(type_spec, computation_types.Type)
    if not value_type.is_equivalent_to(type_spec):
        raise TypeError('Expected a value of type {}, found {}.'.format(
            type_spec, value_type))
    return type_spec
Exemplo n.º 26
0
def _is_two_tuple(t: computation_types.Type) -> bool:
    return t.is_struct() and len(t) == 2
Exemplo n.º 27
0
 def federated_empty_struct(type_spec: computation_types.Type) -> bool:
     """Returns whether `type_spec` is a struct type or a federated type."""
     is_struct = type_spec.is_struct()
     return is_struct or type_spec.is_federated()
def create_constant(
        value, type_spec: computation_types.Type) -> ComputationProtoAndType:
    """Returns a tensorflow computation returning a constant `value`.

    The returned computation has the type signature `( -> T)`, where `T` is
    `type_spec`.

    `value` must be a value convertible to a tensor or a structure of values,
    such that the dtype and shapes match `type_spec`. `type_spec` must contain
    only named tuples and tensor types, but these can be arbitrarily nested.
    When `value` is not a structure it is reused for every tensor leaf of
    `type_spec`.

    Args:
      value: A value to embed as a constant in the tensorflow graph.
      type_spec: A `computation_types.Type` describing the result of the
        computation; must contain only named tuples and tensor types, and
        every tensor shape must be fully defined.

    Returns:
      The result of `_tensorflow_comp` for the serialized graph and the
      function type `( -> type_spec)`.

    Raises:
      TypeError: If the constraints of `type_spec` are violated, if a
        structure `value` is not assignable to `type_spec`, or if a
        non-integer scalar is supplied for a `type_spec` containing integer
        tensors.
    """
    if not type_analysis.is_generic_op_compatible_type(type_spec):
        raise TypeError(
            'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
            ' only nested tuples and tensors are permitted.'.format(type_spec))
    inferred_value_type = type_conversions.infer_type(value)
    # Only structure values are checked for assignability up front; a
    # non-structure value may legitimately differ from `type_spec` because it
    # is reused for every leaf further below.
    if (inferred_value_type.is_struct()
            and not type_spec.is_assignable_from(inferred_value_type)):
        raise TypeError(
            'Must pass a only tensor or structure of tensor values to '
            '`create_tensorflow_constant`; encountered a value {v} with inferred '
            'type {t!r}, but needed {s!r}'.format(v=value,
                                                  t=inferred_value_type,
                                                  s=type_spec))
    if inferred_value_type.is_struct():
        value = structure.from_container(value, recursive=True)
    # Filled in by `_pack_dtypes` while walking `type_spec`.
    tensor_dtypes_in_type_spec = []

    def _pack_dtypes(type_signature):
        """Appends dtype of `type_signature` to nonlocal variable."""
        if type_signature.is_tensor():
            tensor_dtypes_in_type_spec.append(type_signature.dtype)
        # The walk is used only for its side effect; the type is unchanged.
        return type_signature, False

    type_transformations.transform_type_postorder(type_spec, _pack_dtypes)

    # Reject a non-integer tensor value when any leaf of `type_spec` is an
    # integer tensor.
    if (any(x.is_integer for x in tensor_dtypes_in_type_spec)
            and (inferred_value_type.is_tensor()
                 and not inferred_value_type.dtype.is_integer)):
        raise TypeError(
            'Only integers can be used as scalar values if our desired constant '
            'type spec contains any integer tensors; passed scalar {} of dtype {} '
            'for type spec {}.'.format(value, inferred_value_type.dtype,
                                       type_spec))

    result_type = type_spec

    def _create_result_tensor(type_spec, value):
        """Packs `value` into `type_spec` recursively."""
        # NOTE: `type_spec` and `value` here shadow the outer function's
        # arguments; `inferred_value_type` is the type inferred for the
        # top-level value and is consulted at every level of the recursion.
        if type_spec.is_tensor():
            type_spec.shape.assert_is_fully_defined()
            result = tf.constant(value,
                                 dtype=type_spec.dtype,
                                 shape=type_spec.shape)
        else:
            elements = []
            if inferred_value_type.is_struct():
                # Copy the leaf values according to the type_spec structure.
                # The loop variable rebinds `value`; the iterable passed to
                # `zip` is evaluated before the first rebinding.
                for (name, elem_type), value in zip(
                        structure.iter_elements(type_spec), value):
                    elements.append(
                        (name, _create_result_tensor(elem_type, value)))
            else:
                # "Broadcast" the value to each level of the type_spec structure.
                # Elements built this way are left unnamed.
                for _, elem_type in structure.iter_elements(type_spec):
                    elements.append(
                        (None, _create_result_tensor(elem_type, value)))
            result = structure.Struct(elements)
        return result

    # Build the constants inside a fresh graph and capture their bindings.
    with tf.Graph().as_default() as graph:
        result = _create_result_tensor(result_type, value)
        _, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    # The computation takes no parameter, hence `None` for the parameter type
    # and the parameter binding.
    type_signature = computation_types.FunctionType(None, result_type)
    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=None,
                               result=result_binding)
    return _tensorflow_comp(tensorflow, type_signature)
Exemplo n.º 29
0
def type_to_tf_structure(type_spec: computation_types.Type):
    """Returns nested `tf.data.experimental.Structure` for a given TFF type.

    The mapping is:
      * tensor type -> `tf.TensorSpec` with the same shape and dtype;
      * empty struct -> `()`;
      * struct without a Python container -> `collections.OrderedDict` when
        the elements are named, otherwise a `tuple`;
      * struct whose container is a named tuple or `attrs` class -> an
        instance of that container built from keyword arguments;
      * struct whose container is `tf.RaggedTensor` -> `tf.RaggedTensorSpec`;
      * struct whose container is `tf.SparseTensor` -> `tf.SparseTensorSpec`;
      * any other container -> `container_type(...)` called with the
        `(name, value)` pairs when named, or with the bare values otherwise.

    Args:
      type_spec: A `computation_types.Type` composed only of structs and
        tensors. Within one struct, the elements must be either all named or
        all unnamed.

    Returns:
      An instance of `tf.data.experimental.Structure`, possibly nested, that
      corresponds to `type_spec`.

    Raises:
      ValueError: If `type_spec` contains something other than structs and
        tensors, or if the elements of a struct are inconsistently named.
    """

    def _static_dim(tensor_type, axis):
        """Returns the statically known size of `axis`, or `None`."""
        shape = tensor_type.shape
        if shape is None or shape.rank is None or shape.rank <= axis:
            return None
        # `shape[axis]` is an int/None under TF2 shape behavior and a
        # `Dimension` under TF1 behavior; `dimension_value` normalizes both.
        return tf.compat.dimension_value(shape[axis])

    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        return tf.TensorSpec(type_spec.shape, type_spec.dtype)
    if not type_spec.is_struct():
        raise ValueError('Unsupported type {}.'.format(
            py_typecheck.type_string(type(type_spec))))

    elements = structure.to_elements(type_spec)
    if not elements:
        return ()
    element_outputs = [(k, type_to_tf_structure(v)) for k, v in elements]
    # The first element decides whether the struct counts as named; every
    # other element has to agree with it.
    named = element_outputs[0][0] is not None
    if not all((e[0] is not None) == named for e in element_outputs):
        raise ValueError('Tuple elements inconsistently named.')

    container_type = type_spec.python_container
    if container_type is None:
        if named:
            return collections.OrderedDict(element_outputs)
        return tuple(v for _, v in element_outputs)

    if (py_typecheck.is_named_tuple(container_type)
            or py_typecheck.is_attrs(container_type)):
        return container_type(**dict(element_outputs))

    if container_type is tf.RaggedTensor:
        flat_values = type_spec.flat_values
        nested_row_splits = type_spec.nested_row_splits
        ragged_rank = len(nested_row_splits)
        return tf.RaggedTensorSpec(
            shape=tf.TensorShape([None] * (ragged_rank + 1)),
            dtype=flat_values.dtype,
            ragged_rank=ragged_rank,
            row_splits_dtype=nested_row_splits[0].dtype,
            flat_values_spec=None)

    if container_type is tf.SparseTensor:
        # We can't generally infer the shape from the type of the tensors, but
        # we *can* infer the rank from the second dimension of `indices` or,
        # failing that, from the length of `dense_shape`.
        rank = _static_dim(type_spec.indices, 1)
        if rank is None:
            rank = _static_dim(type_spec.dense_shape, 0)
        shape = None if rank is None else tf.TensorShape([None] * rank)
        return tf.SparseTensorSpec(shape=shape, dtype=type_spec.values.dtype)

    if named:
        return container_type(element_outputs)
    # All names are `None` here (checked above), so only the values are passed.
    return container_type(v for _, v in element_outputs)
Exemplo n.º 30
0
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]):
    """Applies `transform_fn` to every node of a type tree, children first.

    Each child type is transformed recursively before its parent. If any child
    was mutated, the parent is rebuilt from the transformed children and only
    then handed to `transform_fn`.

    Args:
      type_signature: Instance of `computation_types.Type` to transform
        recursively.
      transform_fn: Callable taking a type and returning a
        `(type, mutated)` pair, where `mutated` tells whether the type was
        changed.

    Returns:
      A `(type, mutated)` pair: the possibly transformed version of
      `type_signature`, and a flag that is truthy if `transform_fn` reported a
      mutation for this node or for any node below it.

    Raises:
      TypeError: Presumably raised by the `py_typecheck` argument checks when
        the arguments don't match the specification above (those helpers are
        defined outside this function).
    """
    py_typecheck.check_type(type_signature, computation_types.Type)
    py_typecheck.check_callable(transform_fn)

    def _visit(node, children_changed):
        """Runs `transform_fn` on `node` and merges in the children's flag."""
        node, node_changed = transform_fn(node)
        return node, node_changed or children_changed

    if type_signature.is_federated():
        new_member, member_changed = transform_type_postorder(
            type_signature.member, transform_fn)
        node = type_signature
        if member_changed:
            node = computation_types.FederatedType(
                new_member, type_signature.placement,
                type_signature.all_equal)
        return _visit(node, member_changed)
    elif type_signature.is_sequence():
        new_element, element_changed = transform_type_postorder(
            type_signature.element, transform_fn)
        node = type_signature
        if element_changed:
            node = computation_types.SequenceType(new_element)
        return _visit(node, element_changed)
    elif type_signature.is_function():
        # A function type may have no parameter; only the result is mandatory.
        if type_signature.parameter is None:
            new_parameter, parameter_changed = None, False
        else:
            new_parameter, parameter_changed = transform_type_postorder(
                type_signature.parameter, transform_fn)
        new_result, result_changed = transform_type_postorder(
            type_signature.result, transform_fn)
        node = type_signature
        if parameter_changed or result_changed:
            node = computation_types.FunctionType(new_parameter, new_result)
        return _visit(node, parameter_changed or result_changed)
    elif type_signature.is_struct():
        rebuilt = []
        any_child_changed = False
        for name, child in structure.iter_elements(type_signature):
            new_child, child_changed = transform_type_postorder(
                child, transform_fn)
            any_child_changed = any_child_changed or child_changed
            rebuilt.append((name, new_child))
        node = type_signature
        if any_child_changed:
            # Keep the Python container when the original struct carried one.
            if type_signature.is_struct_with_python():
                node = computation_types.StructWithPythonType(
                    rebuilt, type_signature.python_container)
            else:
                node = computation_types.StructType(rebuilt)
        return _visit(node, any_child_changed)
    elif (type_signature.is_abstract() or type_signature.is_placement()
          or type_signature.is_tensor()):
        # Leaf types have no children, so `transform_fn` decides everything.
        return transform_fn(type_signature)