def create_broadcast_scalar_to_shape(scalar_type: tf.DType,
                                     shape: tf.TensorShape) -> pb.Computation:
  """Constructs a TF computation that broadcasts a scalar via `tf.broadcast_to`.

  The constructed computation has type signature `(T -> U)`: `T` is the
  scalar tensor type with dtype `scalar_type`, and `U` is a `tff.TensorType`
  with dtype `scalar_type` and shape `shape`.

  Args:
    scalar_type: A `tf.DType`, the dtype of the scalar to broadcast.
    shape: A fully defined `tf.TensorShape` to broadcast to.

  Raises:
    TypeError: If `scalar_type` is not a `tf.DType` or if `shape` is not a
      `tf.TensorShape`.
    ValueError: If `shape` is not fully defined.
  """
  py_typecheck.check_type(scalar_type, tf.DType)
  py_typecheck.check_type(shape, tf.TensorShape)
  shape.assert_is_fully_defined()
  scalar_tff_type = computation_types.TensorType(scalar_type, shape=())

  graph = tf.Graph()
  with graph.as_default():
    # Stamp the scalar argument into the graph, then broadcast it.
    arg_value, arg_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', scalar_tff_type, graph)
    broadcast_value = tf.broadcast_to(arg_value, shape)
    broadcast_type, broadcast_binding = tensorflow_utils.capture_result_from_graph(
        broadcast_value, graph)

  function_type = computation_types.FunctionType(scalar_tff_type,
                                                 broadcast_type)
  tensorflow_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=arg_binding,
      result=broadcast_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=tensorflow_proto)
# ---- Example #2 (scraper separator; original marker "Exemple #2", score 0) ----
    def test_counts_correct_number_of_ops_with_function(self):
        @tf.function
        def add_one(x):
            return x + 1

        graph = tf.Graph()
        with graph.as_default():
            param_value, param_binding = tensorflow_utils.stamp_parameter_in_graph(
                'x', tf.int32, graph)
            chained_result = add_one(add_one(param_value))

        res_type, res_binding = tensorflow_utils.capture_result_from_graph(
            chained_result, graph)
        fn_type = computation_types.FunctionType(tf.int32, res_type)
        tf_proto = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=param_binding,
            result=res_binding)
        comp_proto = pb.Computation(
            type=type_serialization.serialize_type(fn_type),
            tensorflow=tf_proto)
        block = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)

        op_count = building_block_analysis.count_tensorflow_ops_in(block)

        # A total of 7 ops is expected:
        #   Inside the tf.function body:
        #     - one constant
        #     - one addition
        #     - one identity on the function result
        #   Inside the tff_computation itself:
        #     - one placeholder (for the argument)
        #     - two partition calls (one per `add_one` invocation)
        #     - one identity on the tff_computation result
        self.assertEqual(op_count, 7)
# ---- Example #3 (scraper separator; original marker "Exemple #3", score 0) ----
def create_computation_for_py_fn(
        fn: types.FunctionType,
        parameter_type: Optional[computation_types.Type]) -> pb.Computation:
    """Returns a tensorflow computation returning the result of `fn`.

    The returned computation has the type signature `(T -> U)`, where `T` is
    `parameter_type` and `U` is the type returned by `fn`.

    Args:
      fn: A Python function.
      parameter_type: A `computation_types.Type`, or `None` if `fn` takes no
        argument.
    """
    py_typecheck.check_type(fn, types.FunctionType)
    if parameter_type is not None:
        py_typecheck.check_type(parameter_type, computation_types.Type)

    graph = tf.Graph()
    with graph.as_default():
        if parameter_type is None:
            # Nullary function: nothing to stamp, no parameter binding.
            parameter_binding = None
            fn_output = fn()
        else:
            parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                'x', parameter_type, graph)
            fn_output = fn(parameter_value)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            fn_output, graph)

    function_type = computation_types.FunctionType(parameter_type, result_type)
    tensorflow_proto = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    return pb.Computation(
        type=type_serialization.serialize_type(function_type),
        tensorflow=tensorflow_proto)
# ---- Example #4 (scraper separator; original marker "Exemple #4", score 0) ----
def tf_computation_serializer(parameter_type: Optional[computation_types.Type],
                              context_stack):
    """Serializes a TF computation with a given parameter type.

  Args:
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `computation_types.Type`.
    context_stack: The context stack to use.

  Yields:
    The first yielded value will be a Python object (such as a dataset,
    a placeholder, or a `structure.Struct`) to be passed to the function to
    serialize. The result of the function should then be passed to the
    following `send` call.
    The next yielded value will be
    a tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if parameter_type is not None:
        py_typecheck.check_type(parameter_type, computation_types.Type)

    with tf.Graph().as_default() as graph:
        # Stamp the parameter (if any) into the graph so the function being
        # serialized can consume it as a concrete TF value.
        if parameter_type is not None:
            parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                'arg', parameter_type, graph)
        else:
            parameter_value = None
            parameter_binding = None
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph)
        with context_stack.install(context):
            # Coroutine protocol: hand the stamped parameter to the caller and
            # receive, via the subsequent `send`, the result of invoking the
            # function being serialized. Any `tf.Variable`s created while
            # tracing are recorded in `all_variables`.
            with variable_utils.record_variable_creation_scope(
            ) as all_variables:
                result = yield parameter_value
            initializer_ops = []
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
                initializer_ops.append(
                    tf.compat.v1.initializers.variables(all_variables,
                                                        name=name))
            # Table initializers (e.g. lookup tables) are tracked via a graph
            # collection rather than the variable creation scope, so pick them
            # up separately.
            initializer_ops.extend(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
            if initializer_ops:
                # Before running the main new init op, run any initializers for sub-
                # computations from context.init_ops. Variables from import_graph_def
                # will not make it into the global collections, and so will not be
                # initialized without this code path.
                with tf.compat.v1.control_dependencies(context.init_ops):
                    init_op_name = tf.group(*initializer_ops,
                                            name='grouped_initializers').name
            elif context.init_ops:
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(parameter_type,
                                                    result_type)

    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=parameter_binding,
                               result=result_binding,
                               initialize_op=init_op_name)
    yield pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow), type_signature
# ---- Example #5 (scraper separator; original marker "Exemple #5", score 0) ----
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ProtoAndType:
  """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both of
      the same type or the second able to be upcast to the first, as explained
      in `apply_binary_operator_with_upcast`, and both containing only tuples
      and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `building_blocks.CompiledComputation` encapsulating a function which
    upcasts the second element of its argument and applies the binary
    operator.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not type_signature.is_struct() or len(type_signature) != 2:
    raise TypeError('To apply a binary operator, we must by definition have an '
                    'argument which is a `StructType` with 2 elements; '
                    'asked to create a binary operator for type: {t}'.format(
                        t=type_signature))
  if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
    raise TypeError(
        'Applying binary operators in TensorFlow is only '
        'supported on Tensors and StructTypes; you '
        'passed {t} which contains a SequenceType.'.format(t=type_signature))

  def _broadcast_into(scalar, target_type):
    """Broadcast tensor value `scalar` into the nested structure `target_type`."""
    if target_type.is_tensor():
      return tf.broadcast_to(scalar, target_type.shape)
    elif target_type.is_struct():
      packed_elements = [(element_name, _broadcast_into(scalar, element_type))
                         for element_name, element_type in
                         structure.iter_elements(target_type)]
      return structure.Struct(packed_elements)

  left_type, right_type = type_signature[0], type_signature[1]
  with tf.Graph().as_default() as graph:
    left_value, left_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', left_type, graph)
    right_value, right_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', right_type, graph)
    # Upcast the second operand into the first operand's structure unless the
    # two types already agree.
    if not left_type.is_equivalent_to(right_type):
      right_value = _broadcast_into(right_value, left_type)

    if left_type.is_tensor():
      op_result = operator(left_value, right_value)
    elif left_type.is_struct():
      op_result = structure.map_structure(operator, left_value, right_value)
    else:
      raise TypeError('Encountered unexpected type {t}; can only handle Tensor '
                      'and StructTypes.'.format(t=left_type))

  result_type, result_binding = tensorflow_utils.capture_result_from_graph(
      op_result, graph)

  function_type = computation_types.FunctionType(type_signature, result_type)
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[left_binding, right_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, function_type)
# ---- Example #6 (scraper separator; original marker "Exemple #6", score 0) ----
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`

  Note: If `operand_type` is a `computation_types.StructType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode For example: `tf.math.add`, `tf.math.multiply`, and
        `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Returns:
    A serialized computation applying `operator`, together with its type
    signature.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or `operator`
      is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.StructType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)

    # Dispatch directly on the operand type. The previous
    # `if operand_type is not None:` wrapper made the `else: raise` clause
    # unreachable for non-None types, so an unexpected-but-allowed type would
    # have crashed with an unbound `result_value` (NameError) instead of the
    # intended TypeError below.
    if operand_type.is_tensor():
      result_value = operator(operand_1_value, operand_2_value)
    elif operand_type.is_struct():
      # Apply the operator pointwise over the leaves of the struct.
      result_value = structure.map_structure(operator, operand_1_value,
                                             operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.StructType((operand_type, operand_type)), result_type)
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
# ---- Example #7 (scraper separator; original marker "Exemple #7", score 0) ----
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2
  serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero parameters
      if `parameter_type` is `None`, or exactly one parameter if it's not
      `None`.  The nested structure of this parameter must correspond to the
      structure of the 'parameter_type'. In the future, we may support targets
      with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.

  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  # Inspect the target's Python argument list so it can be validated against
  # `parameter_type` below.
  argspec = function_utils.get_argspec(target)

  with tf.Graph().as_default() as graph:
    args = []
    if parameter_type is not None:
      if len(argspec.args) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, found {!r}.'
            .format(argspec.args))
      parameter_name = argspec.args[0]
      # Stamp the parameter into the graph; the resulting TF value is passed
      # to `target` as its single positional argument.
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
      args.append(parameter_value)
    else:
      if argspec.args:
        raise ValueError(
            'Expected the target to declare no parameters, found {!r}.'.format(
                argspec.args))
      parameter_binding = None
    # Invoke the target with a TFF TensorFlow computation context installed;
    # init ops contributed by sub-computations accumulate on
    # `context.init_ops`.
    context = tf_computation_context.TensorFlowComputationContext(graph)
    with context_stack.install(context):
      result = target(*args)

      # TODO(b/122081673): This needs to change for TF 2.0. We may also
      # want to allow the person creating a tff.tf_computation to specify
      # a different initializer; e.g., if it is known that certain
      # variables will be assigned immediately to arguments of the function,
      # then it is wasteful to initialize them before this.
      #
      # The following is a bit of a work around: the collections below may
      # contain variables more than once, hence we throw into a set. TFF needs
      # to ensure all variables are initialized, but not all variables are
      # always in the collections we expect. tff.learning._KerasModel tries to
      # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
      # VARS_FOR_TFF_TO_INITIALIZE for now.
      all_variables = set(tf.compat.v1.global_variables() +
                          tf.compat.v1.local_variables() +
                          tf.compat.v1.get_collection(
                              graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE))
      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in all_variables])
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(len(all_variables))
        with tf.control_dependencies(context.init_ops):
          # Before running the main new init op, run any initializers for sub-
          # computations from context.init_ops. Variables from import_graph_def
          # will not make it into the global collections, and so will not be
          # initialized without this code path.
          init_op_name = tf.compat.v1.initializers.variables(
              all_variables, name=name).name
      elif context.init_ops:
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  # The annotated type retains Python container information for the caller;
  # the proto below carries only the serialized parameter/result types.
  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=parameter_binding,
          result=result_binding,
          initialize_op=init_op_name)), annotated_type
def create_binary_operator_with_upcast(
        type_signature: computation_types.StructType,
        operator: Callable[[Any, Any], Any]) -> ComputationProtoAndType:
    """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both
      only containing structs or tensors in their type tree. The first and
      second element must match in structure, or the second element may be a
      single tensor type that is broadcasted (upcast) to the leaves of the
      structure of the first type.
    operator: Callable defining the operator.

  Returns:
    Same as `create_binary_operator()`.
  """
    py_typecheck.check_type(type_signature, computation_types.StructType)
    py_typecheck.check_callable(operator)
    type_analysis.check_tensorflow_compatible_type(type_signature)
    if not type_signature.is_struct() or len(type_signature) != 2:
        raise TypeError(
            'To apply a binary operator, we must by definition have an '
            'argument which is a `StructType` with 2 elements; '
            'asked to create a binary operator for type: {t}'.format(
                t=type_signature))
    if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
        raise TypeError('Applying binary operators in TensorFlow is only '
                        'supported on Tensors and StructTypes; you '
                        'passed {t} which contains a SequenceType.'.format(
                            t=type_signature))

    def _pack_into_type(to_pack, type_spec):
        """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
        if type_spec.is_struct():
            elem_iter = structure.iter_elements(type_spec)
            return structure.Struct([(elem_name,
                                      _pack_into_type(to_pack, elem_type))
                                     for elem_name, elem_type in elem_iter])
        elif type_spec.is_tensor():
            # Broadcast the single tensor to each leaf of the target type.
            return tf.broadcast_to(to_pack, type_spec.shape)

    with tf.Graph().as_default() as graph:
        first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', type_signature[0], graph)
        operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
            'y', type_signature[1], graph)

        if type_signature[0].is_struct() and type_signature[1].is_struct():
            # If both the first and second arguments are structs with the same
            # structure, simply re-use operand_2_value as the second argument.
            # `structure.map_structure` below will map the binary operator
            # pointwise to the leaves of the structure.
            if structure.is_same_structure(type_signature[0],
                                           type_signature[1]):
                second_arg = operand_2_value
            else:
                # Structure-to-different-structure upcasting is not supported.
                raise TypeError(
                    'Cannot upcast one structure to a different structure. '
                    '{x} -> {y}'.format(x=type_signature[1],
                                        y=type_signature[0]))
        elif type_signature[0].is_equivalent_to(type_signature[1]):
            # Types already agree; no upcast needed.
            second_arg = operand_2_value
        else:
            # Second element is a single tensor: broadcast it into the first
            # element's structure.
            second_arg = _pack_into_type(operand_2_value, type_signature[0])

        if type_signature[0].is_tensor():
            result_value = operator(first_arg, second_arg)
        elif type_signature[0].is_struct():
            result_value = structure.map_structure(operator, first_arg,
                                                   second_arg)
        else:
            raise TypeError(
                'Encountered unexpected type {t}; can only handle Tensor '
                'and StructTypes.'.format(t=type_signature[0]))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

    type_signature = computation_types.FunctionType(type_signature,
                                                    result_type)
    parameter_binding = pb.TensorFlow.Binding(
        struct=pb.TensorFlow.StructBinding(
            element=[operand_1_binding, operand_2_binding]))
    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=parameter_binding,
                               result=result_binding)
    return _tensorflow_comp(tensorflow, type_signature)
# ---- Example #9 (scraper separator; original marker "Exemple #9", score 0) ----
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
    """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2
  serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero parameters
      if `parameter_type` is `None`, or exactly one parameter if it's not
      `None`.  The nested structure of this parameter must correspond to the
      structure of the 'parameter_type'. In the future, we may support targets
      with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(target, types.FunctionType)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    parameter_type = computation_types.to_type(parameter_type)
    # Inspect the target's Python signature so it can be validated against
    # `parameter_type` below.
    signature = function_utils.get_signature(target)

    with tf.Graph().as_default() as graph:
        if parameter_type is not None:
            if len(signature.parameters) != 1:
                raise ValueError(
                    'Expected the target to declare exactly one parameter, found {!r}.'
                    .format(signature.parameters))
            parameter_name = next(iter(signature.parameters))
            # Stamp the parameter into the graph; the resulting TF value is
            # passed to `target` as its single argument.
            parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                parameter_name, parameter_type, graph)
        else:
            if signature.parameters:
                raise ValueError(
                    'Expected the target to declare no parameters, found {!r}.'
                    .format(signature.parameters))
            parameter_value = None
            parameter_binding = None
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph)
        with context_stack.install(context):
            # Invoke the target under a variable creation scope so every
            # `tf.Variable` created during tracing is captured in
            # `all_variables` for initialization below.
            with variable_utils.record_variable_creation_scope(
            ) as all_variables:
                if parameter_value is not None:
                    result = target(parameter_value)
                else:
                    result = target()
            initializer_ops = []
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
                initializer_ops.append(
                    tf.compat.v1.initializers.variables(all_variables,
                                                        name=name))
            # Table initializers (e.g. lookup tables) live in a graph
            # collection rather than the variable creation scope.
            initializer_ops.extend(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
            if initializer_ops:
                # Before running the main new init op, run any initializers for sub-
                # computations from context.init_ops. Variables from import_graph_def
                # will not make it into the global collections, and so will not be
                # initialized without this code path.
                with tf.compat.v1.control_dependencies(context.init_ops):
                    init_op_name = tf.group(*initializer_ops,
                                            name='grouped_initializers').name
            elif context.init_ops:
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(parameter_type,
                                                    result_type)

    # WARNING: we do not really want to be modifying the graph here if we can
    # avoid it. This is purely to work around performance issues uncovered with
    # the non-standard usage of Tensorflow and have been discussed with the
    # Tensorflow core team before being added.
    clean_graph_def = _clean_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(clean_graph_def),
        parameter=parameter_binding,
        result=result_binding,
        initialize_op=init_op_name)
    return pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow), type_signature
def create_binary_operator(operator, operand_type) -> pb.Computation:
    """Builds a tensorflow computation that applies the binary `operator`.

    The computation produced here has type signature `(<T,T> -> U)`: `T` is
    `operand_type` and `U` is whatever `operator` yields when applied to a
    tuple of type `<T,T>`.

    Note: For a `computation_types.NamedTupleType` operand, `operator` is
    applied pointwise to the leaves, so callers are responsible for packing
    values appropriately — e.g. dividing `[2, 2]` by `2` requires first
    packing `2` into `[x, x]` before calling the constructed division.

    Args:
      operator: A two-argument callable defining the operation to encode,
        e.g. `tf.math.add`, `tf.math.multiply`, or `tf.math.divide`.
      operand_type: The type of the constructed operator's argument; anything
        convertible via `computation_types.to_type`, containing only types
        compatible with the TFF generic operators (named tuples and tensors).

    Raises:
      TypeError: If the constraints of `operand_type` are violated or
        `operator` is not callable.
    """
    operand_type = computation_types.to_type(operand_type)
    if not type_utils.is_generic_op_compatible_type(operand_type):
        raise TypeError(
            'The type {} contains a type other than `computation_types.TensorType` '
            'and `computation_types.NamedTupleType`; this is disallowed in the '
            'generic operators.'.format(operand_type))
    py_typecheck.check_callable(operator)
    graph = tf.Graph()
    with graph.as_default():
        left_value, left_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', operand_type, graph)
        right_value, right_binding = tensorflow_utils.stamp_parameter_in_graph(
            'y', operand_type, graph)

        if isinstance(operand_type, computation_types.NamedTupleType):
            # Pointwise application over the leaves of the named tuple.
            op_result = anonymous_tuple.map_structure(operator, left_value,
                                                      right_value)
        elif isinstance(operand_type, computation_types.TensorType):
            op_result = operator(left_value, right_value)
        else:
            raise TypeError(
                'Operand type {} cannot be used in generic operations. The whitelist '
                'in `type_utils.is_generic_op_compatible_type` has allowed it to '
                'pass, and should be updated.'.format(operand_type))
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            op_result, graph)

    function_type = computation_types.FunctionType(
        [operand_type, operand_type], result_type)
    parameter_binding = pb.TensorFlow.Binding(
        tuple=pb.TensorFlow.NamedTupleBinding(
            element=[left_binding, right_binding]))
    tensorflow_proto = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    return pb.Computation(
        type=type_serialization.serialize_type(function_type),
        tensorflow=tensorflow_proto)