def create_unary_operator(
        operator,
        operand_type: computation_types.Type) -> ComputationProtoAndType:
    """Builds a TensorFlow computation that applies a unary `operator`.

    The computation has the type signature `(T -> U)`: `T` is `operand_type`
    and `U` is the type captured from the graph after calling `operator` on a
    parameter of type `T` stamped into a fresh `tf.Graph`.

    Args:
      operator: A one-argument callable encoding the operation, for example
        `tf.math.abs`.
      operand_type: The `computation_types.Type` of the operator's argument;
        it must contain only named tuples and tensor types.

    Returns:
      Whatever `_tensorflow_comp` produces for the assembled `pb.TensorFlow`
      proto together with its `(T -> U)` function type.

    Raises:
      TypeError: If `operand_type` is `None` or violates the constraints
        above, or if `operator` is not callable.
    """
    # The type is validated first, then the operator, so the type error wins
    # when both arguments are bad.
    type_is_usable = (
        operand_type is not None
        and type_analysis.is_generic_op_compatible_type(operand_type))
    if not type_is_usable:
        raise TypeError(
            '`operand_type` contains a type other than '
            '`computation_types.TensorType` and `computation_types.StructType`; '
            f'this is disallowed in the generic operators. Got: {operand_type} '
        )
    py_typecheck.check_callable(operator)

    with tf.Graph().as_default() as graph:
        argument, argument_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', operand_type, graph)
        # `operator` runs inside the graph context so its ops land in `graph`.
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            operator(argument), graph)

    signature = computation_types.FunctionType(operand_type, result_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow_proto = pb.TensorFlow(
        graph_def=packed_graph_def,
        parameter=argument_binding,
        result=result_binding)
    return _tensorflow_comp(tensorflow_proto, signature)
# Example no. 2
# 0
def create_constant(scalar_value,
                    type_spec: computation_types.Type) -> ProtoAndType:
  """Builds a no-argument TensorFlow computation yielding a constant.

  The computation has the type signature `( -> T)`, where `T` is `type_spec`.
  Every tensor leaf of `type_spec` is filled with `scalar_value`.

  `scalar_value` must be a scalar, and it must be an integer whenever any
  tensor leaf of `type_spec` has an integer dtype. `type_spec` may nest named
  tuples and tensor types arbitrarily, but may contain nothing else.

  Args:
    scalar_value: The scalar to place in every tensor leaf of `type_spec`.
    type_spec: The `computation_types.Type` of the constant to construct; must
      contain only named tuples and tensor types.

  Returns:
    Whatever `_tensorflow_comp` produces for the assembled `pb.TensorFlow`
    proto together with its `( -> T)` function type.

  Raises:
    TypeError: If `type_spec` violates the constraints above, if
      `scalar_value` is not inferred to be a scalar tensor, or if a
      non-integer scalar is given for a `type_spec` with integer leaves.
  """
  if not type_analysis.is_generic_op_compatible_type(type_spec):
    raise TypeError(
        'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
        ' only nested tuples and tensors are permitted.'.format(type_spec))

  scalar_type = type_conversions.infer_type(scalar_value)
  if (not scalar_type.is_tensor() or
      scalar_type.shape != tf.TensorShape(())):
    raise TypeError(
        'Must pass a scalar value to `create_tensorflow_constant`; encountered '
        'a value {}'.format(scalar_value))

  leaf_dtypes = []

  def _record_leaf_dtype(visited_type):
    """Collects the dtype of each visited tensor type; never transforms."""
    if visited_type.is_tensor():
      leaf_dtypes.append(visited_type.dtype)
    return visited_type, False

  # The traversal is used purely for its side effect of filling `leaf_dtypes`.
  type_transformations.transform_type_postorder(type_spec, _record_leaf_dtype)

  spec_has_integer_leaf = any(dtype.is_integer for dtype in leaf_dtypes)
  if spec_has_integer_leaf and not scalar_type.dtype.is_integer:
    raise TypeError(
        'Only integers can be used as scalar values if our desired constant '
        'type spec contains any integer tensors; passed scalar {} of dtype {} '
        'for type spec {}.'.format(scalar_value, scalar_type.dtype, type_spec))

  def _fill(spec):
    """Recursively materializes `scalar_value` at each tensor leaf of `spec`."""
    if not spec.is_tensor():
      return [_fill(child) for _, child in structure.iter_elements(spec)]
    # The shape must be fully defined before a constant is built for it.
    spec.shape.assert_is_fully_defined()
    return tf.constant(scalar_value, dtype=spec.dtype, shape=spec.shape)

  with tf.Graph().as_default() as graph:
    _, result_binding = tensorflow_utils.capture_result_from_graph(
        _fill(type_spec), graph)

  signature = computation_types.FunctionType(None, type_spec)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow_proto = pb.TensorFlow(
      graph_def=packed_graph_def, parameter=None, result=result_binding)
  return _tensorflow_comp(tensorflow_proto, signature)
# Example no. 3
# 0
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`.

  Note: If `operand_type` is a `computation_types.StructType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode. For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Returns:
    Whatever `_tensorflow_comp` produces for the assembled `pb.TensorFlow`
    proto together with its `(<T,T> -> U)` function type.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or `operator`
      is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.StructType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)

    # A single if/elif/else chain guarantees `result_value` is bound on every
    # path that continues. Previously the `else` belonged only to the `None`
    # check, so a non-None type that was neither tensor nor struct fell through
    # and failed later with an unbound-variable error instead of a `TypeError`.
    if operand_type is not None and operand_type.is_tensor():
      result_value = operator(operand_1_value, operand_2_value)
    elif operand_type is not None and operand_type.is_struct():
      # Structures are combined pointwise, leaf by leaf.
      result_value = structure.map_structure(operator, operand_1_value,
                                             operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.StructType((operand_type, operand_type)), result_type)
  # Both operands are exposed as one two-element struct parameter.
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def create_constant(
        value, type_spec: computation_types.Type) -> ComputationProtoAndType:
    """Builds a no-argument TensorFlow computation yielding a constant.

    The computation has the type signature `( -> T)`, where `T` is
    `type_spec`.

    `value` is either a single value, which is copied into every tensor leaf
    of `type_spec`, or a structure of values, which is matched element by
    element against `type_spec`. `type_spec` may nest named tuples and tensor
    types arbitrarily, but may contain nothing else.

    Args:
      value: A value to embed as a constant in the tensorflow graph.
      type_spec: The `computation_types.Type` of the constant to construct;
        must contain only named tuples and tensor types.

    Returns:
      Whatever `_tensorflow_comp` produces for the assembled `pb.TensorFlow`
      proto together with its `( -> T)` function type.

    Raises:
      TypeError: If `type_spec` violates the constraints above, if a
        structured `value` has an inferred type not assignable to
        `type_spec`, or if a non-integer tensor value is given for a
        `type_spec` with integer leaves.
    """
    if not type_analysis.is_generic_op_compatible_type(type_spec):
        raise TypeError(
            'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
            ' only nested tuples and tensors are permitted.'.format(type_spec))

    inferred_value_type = type_conversions.infer_type(value)
    if inferred_value_type.is_struct():
        if not type_spec.is_assignable_from(inferred_value_type):
            raise TypeError(
                'Must pass a only tensor or structure of tensor values to '
                '`create_tensorflow_constant`; encountered a value {v} with inferred '
                'type {t!r}, but needed {s!r}'.format(v=value,
                                                      t=inferred_value_type,
                                                      s=type_spec))
        # Normalize any container into a `structure.Struct` before unpacking.
        value = structure.from_container(value, recursive=True)

    leaf_dtypes = []

    def _record_leaf_dtype(visited_type):
        """Collects the dtype of each visited tensor type; never transforms."""
        if visited_type.is_tensor():
            leaf_dtypes.append(visited_type.dtype)
        return visited_type, False

    # The traversal is used purely for its side effect of filling `leaf_dtypes`.
    type_transformations.transform_type_postorder(type_spec,
                                                  _record_leaf_dtype)

    spec_has_integer_leaf = any(dtype.is_integer for dtype in leaf_dtypes)
    if (spec_has_integer_leaf and inferred_value_type.is_tensor()
            and not inferred_value_type.dtype.is_integer):
        raise TypeError(
            'Only integers can be used as scalar values if our desired constant '
            'type spec contains any integer tensors; passed scalar {} of dtype {} '
            'for type spec {}.'.format(value, inferred_value_type.dtype,
                                       type_spec))

    def _materialize(spec, payload):
        """Recursively packs `payload` into the shape described by `spec`."""
        if spec.is_tensor():
            spec.shape.assert_is_fully_defined()
            return tf.constant(payload, dtype=spec.dtype, shape=spec.shape)
        if inferred_value_type.is_struct():
            # Structured input: pair each element of `spec` with the element
            # of `payload` at the same position, keeping the name from `spec`.
            named_children = [
                (name, _materialize(child_spec, child_payload))
                for (name, child_spec), child_payload in zip(
                    structure.iter_elements(spec), payload)
            ]
        else:
            # Single-value input: the same payload is handed to every child.
            named_children = [
                (None, _materialize(child_spec, payload))
                for _, child_spec in structure.iter_elements(spec)
            ]
        return structure.Struct(named_children)

    with tf.Graph().as_default() as graph:
        _, result_binding = tensorflow_utils.capture_result_from_graph(
            _materialize(type_spec, value), graph)

    signature = computation_types.FunctionType(None, type_spec)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow_proto = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=None, result=result_binding)
    return _tensorflow_comp(tensorflow_proto, signature)