Ejemplo n.º 1
0
def _create_two_variable_tensorflow():
    """Builds a no-argument computation that adds two `tf.Variable`s.

    A fresh graph is populated with variables named 'variable1' (initial
    value 0) and 'variable2' (initial value 1) plus their sum. The sum is
    captured as the graph result, and the graph def, result type and result
    binding are handed to `_pack_noarg_graph`.

    Returns:
      Whatever `_pack_noarg_graph` returns for the constructed graph.
    """
    graph = tf.Graph()
    with graph.as_default():
        first = tf.Variable(0, name='variable1')
        second = tf.Variable(1, name='variable2')
        total = first + second

    captured = tensorflow_utils.capture_result_from_graph(total, graph)
    result_type, result_binding = captured
    return _pack_noarg_graph(graph.as_graph_def(), result_type, result_binding)
def _create_proto_with_unnecessary_op():
    """Builds an int32 identity computation whose graph has a superfluous op.

    The parameter 'x' is stamped into a fresh graph and captured directly as
    the result. A `tf.constant(0)` is captured in the same graph as well, but
    the type and binding from that second capture are discarded, so the proto
    only references the binding of the stamped parameter.

    Returns:
      A `pb.Computation` typed `(int32 -> T)`, where `T` is the result type
      captured from the stamped parameter.
    """
    parameter_type = tf.int32

    graph = tf.Graph()
    with graph.as_default():
        stamped = tensorflow_utils.stamp_parameter_in_graph(
            'x', parameter_type, graph)
        parameter_value, parameter_binding = stamped
        captured = tensorflow_utils.capture_result_from_graph(
            parameter_value, graph)
        result_type, result_binding = captured
        # Extra op that is deliberately left out of the result binding.
        extra_constant = tf.constant(0)
        tensorflow_utils.capture_result_from_graph(extra_constant, graph)

    serialized_function_type = type_serialization.serialize_type(
        computation_types.FunctionType(parameter_type, result_type))
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def,
        parameter=parameter_binding,
        result=result_binding)
    return pb.Computation(type=serialized_function_type, tensorflow=tensorflow)
Ejemplo n.º 3
0
    def test_capture_result_with_attrs_of_constants(self):
        """Captures an `attr.s` instance of constants and checks its type."""

        @attr.s
        class TestFoo(object):
            x = attr.ib()
            y = attr.ib()

        default_graph = tf.compat.v1.get_default_graph()
        foo = TestFoo(tf.constant(1), tf.constant(True))
        type_spec, _unused_binding = (
            tensorflow_utils.capture_result_from_graph(foo, default_graph))

        # Field names and dtypes are expected in the compact type string, and
        # the attrs class itself is expected to be kept as the container.
        self.assertEqual(str(type_spec), '<x=int32,y=bool>')
        self.assertIs(type_spec.python_container, TestFoo)
Ejemplo n.º 4
0
 def test_capture_result_with_ragged_tensor(self):
   """Captures a `tf.RaggedTensor` and checks the resulting struct type."""
   with tf.Graph().as_default() as graph:
     ragged = tf.RaggedTensor.from_row_splits([0, 0, 0, 0], [0, 1, 4])
     type_spec, _unused_binding = (
         tensorflow_utils.capture_result_from_graph(ragged, graph))

     # Expected type: a struct tagged with `tf.RaggedTensor` holding the flat
     # values and a tuple-tagged struct with the single row-splits tensor.
     expected_row_splits = computation_types.StructWithPythonType(
         [(None, computation_types.TensorType(tf.int64, [3]))], tuple)
     expected_type = computation_types.StructWithPythonType([
         ('flat_values', computation_types.TensorType(tf.int32, [4])),
         ('nested_row_splits', expected_row_splits),
     ], tf.RaggedTensor)
     self.assert_types_identical(type_spec, expected_type)
Ejemplo n.º 5
0
    def test_capture_result_with_attrs_of_constants(self):
        """Captures an `attr.s` instance and checks type and container."""

        @attr.s
        class TestFoo(object):
            x = attr.ib()
            y = attr.ib()

        default_graph = tf.compat.v1.get_default_graph()
        foo = TestFoo(tf.constant(1), tf.constant(True))
        type_spec, _unused_binding = (
            tensorflow_utils.capture_result_from_graph(foo, default_graph))

        self.assertEqual(str(type_spec), '<x=int32,y=bool>')
        # The captured type must carry the attrs class as its Python
        # container.
        container_type_cls = (
            computation_types.NamedTupleTypeWithPyContainerType)
        self.assertIsInstance(type_spec, container_type_cls)
        recorded_container = container_type_cls.get_container_type(type_spec)
        self.assertIs(recorded_container, TestFoo)
Ejemplo n.º 6
0
def create_dummy_computation_tensorflow_empty():
    """Builds a no-argument TensorFlow computation of type `( -> <>)`.

    Returns:
      A `(proto, type_signature)` pair: the `pb.Computation` wrapping a graph
      whose result was captured from an empty list, and the matching
      `computation_types.FunctionType` with no parameter.
    """
    graph = tf.Graph()
    with graph.as_default():
        captured = tensorflow_utils.capture_result_from_graph([], graph)
    result_type, result_binding = captured

    type_signature = computation_types.FunctionType(None, result_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=None, result=result_binding)
    serialized_type = type_serialization.serialize_type(type_signature)
    proto = pb.Computation(type=serialized_type, tensorflow=tensorflow)
    return proto, type_signature
Ejemplo n.º 7
0
def create_empty_tuple() -> ProtoAndType:
  """Builds a no-argument TensorFlow computation yielding an empty struct.

  The computation's type signature is `( -> <>)`.

  Returns:
    The value produced by `_tensorflow_comp` for the constructed
    `pb.TensorFlow` proto and its function type (annotated `ProtoAndType`).
  """
  graph = tf.Graph()
  with graph.as_default():
    empty_struct = structure.Struct([])
    captured = tensorflow_utils.capture_result_from_graph(empty_struct, graph)
  result_type, result_binding = captured

  function_type = computation_types.FunctionType(None, result_type)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=packed_graph_def, parameter=None, result=result_binding)
  return _tensorflow_comp(tensorflow, function_type)
def create_empty_tuple() -> pb.Computation:
  """Builds a no-argument TensorFlow computation yielding an empty tuple.

  The computation's type signature is `( -> <>)`.

  Returns:
    A `pb.Computation` whose graph result was captured from an empty list.
  """
  graph = tf.Graph()
  with graph.as_default():
    captured = tensorflow_utils.capture_result_from_graph([], graph)
  result_type, result_binding = captured

  function_type = computation_types.FunctionType(None, result_type)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=packed_graph_def, parameter=None, result=result_binding)
  serialized_type = type_serialization.serialize_type(function_type)
  return pb.Computation(type=serialized_type, tensorflow=tensorflow)
Ejemplo n.º 9
0
def create_dummy_computation_tensorflow_identity(type_spec=tf.int32):
    """Builds a TensorFlow identity computation of type `(T -> T)`.

    Args:
      type_spec: The parameter type `T` to stamp into the graph; defaults to
        `tf.int32`.

    Returns:
      A `(proto, type_signature)` pair: the `pb.Computation` whose result is
      the stamped parameter itself, and the matching
      `computation_types.FunctionType`.
    """
    graph = tf.Graph()
    with graph.as_default():
        stamped = tensorflow_utils.stamp_parameter_in_graph(
            'a', type_spec, graph)
        parameter_value, parameter_binding = stamped
        captured = tensorflow_utils.capture_result_from_graph(
            parameter_value, graph)
        result_type, result_binding = captured

    type_signature = computation_types.FunctionType(type_spec, result_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def,
        parameter=parameter_binding,
        result=result_binding)
    serialized_type = type_serialization.serialize_type(type_signature)
    proto = pb.Computation(type=serialized_type, tensorflow=tensorflow)
    return proto, type_signature
  def test_counts_no_variables(self):
    """A graph built only from constants must report zero variables."""
    graph = tf.Graph()
    with graph.as_default():
      total = tf.constant(0) + tf.constant(1)

    _unused_type, result_binding = (
        tensorflow_utils.capture_result_from_graph(total, graph))

    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    serialized_type = type_serialization.serialize_type(
        computation_types.FunctionType(None, tf.int32))
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=None, result=result_binding)
    proto = pb.Computation(type=serialized_type, tensorflow=tensorflow)

    block = building_blocks.ComputationBuildingBlock.from_proto(proto)
    variable_count = building_block_analysis.count_tensorflow_variables_in(
        block)
    self.assertEqual(variable_count, 0)
def create_unary_operator(
        operator,
        operand_type: computation_types.Type) -> ComputationProtoAndType:
    """Builds a TensorFlow computation that applies `operator` to one operand.

    The resulting computation has the type signature `(T -> U)`, where `T` is
    `operand_type` and `U` is the type captured from the value `operator`
    produces when applied to a value of type `T`.

    Args:
      operator: A callable taking one argument that encodes the operation,
        for example `tf.math.abs`.
      operand_type: The `computation_types.Type` of the operand. Must not be
        `None` and must satisfy
        `type_analysis.is_generic_op_compatible_type`.

    Returns:
      The value produced by `_tensorflow_comp` for the constructed
      `pb.TensorFlow` proto and its function type (annotated
      `ComputationProtoAndType`).

    Raises:
      TypeError: If the constraints of `operand_type` are violated or
        `operator` is not callable.
    """
    operand_is_supported = (
        operand_type is not None
        and type_analysis.is_generic_op_compatible_type(operand_type))
    if not operand_is_supported:
        raise TypeError(
            '`operand_type` contains a type other than '
            '`computation_types.TensorType` and `computation_types.StructType`; '
            f'this is disallowed in the generic operators. Got: {operand_type} '
        )
    py_typecheck.check_callable(operator)

    graph = tf.Graph()
    with graph.as_default():
        stamped = tensorflow_utils.stamp_parameter_in_graph(
            'x', operand_type, graph)
        operand_value, operand_binding = stamped
        captured = tensorflow_utils.capture_result_from_graph(
            operator(operand_value), graph)
        result_type, result_binding = captured

    function_type = computation_types.FunctionType(operand_type, result_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def,
        parameter=operand_binding,
        result=result_binding)
    return _tensorflow_comp(tensorflow, function_type)
  def test_gets_none_placement(self):
    """Ops created without a device scope are all reported under ''."""
    graph = tf.Graph()
    with graph.as_default():
      first = tf.Variable(0, name='variable1')
      second = tf.Variable(1, name='variable2')
      total = first + second

    _unused_type, result_binding = (
        tensorflow_utils.capture_result_from_graph(total, graph))

    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    serialized_type = type_serialization.serialize_type(
        computation_types.FunctionType(None, tf.int32))
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=None, result=result_binding)
    proto = pb.Computation(type=serialized_type, tensorflow=tensorflow)

    block = building_blocks.ComputationBuildingBlock.from_proto(proto)
    device_placements = building_block_analysis.get_device_placement_in(block)
    placement_names = list(device_placements.keys())
    # Exactly one placement key is expected: the empty string, with a
    # positive count.
    self.assertLen(placement_names, 1)
    self.assertEqual(placement_names[0], '')
    self.assertGreater(device_placements[''], 0)
Ejemplo n.º 13
0
    def test_counts_correct_number_of_ops_simple_case(self):
        """Counts the ops of a graph that adds two constants."""
        graph = tf.Graph()
        with graph.as_default():
            total = tf.constant(0) + tf.constant(1)

        _unused_type, result_binding = (
            tensorflow_utils.capture_result_from_graph(total, graph))

        packed_graph_def = serialization_utils.pack_graph_def(
            graph.as_graph_def())
        serialized_type = type_serialization.serialize_type(
            computation_types.FunctionType(None, tf.int32))
        tensorflow = pb.TensorFlow(
            graph_def=packed_graph_def, parameter=None, result=result_binding)
        proto = pb.Computation(type=serialized_type, tensorflow=tensorflow)

        block = building_blocks.ComputationBuildingBlock.from_proto(proto)
        op_count = building_block_analysis.count_tensorflow_ops_in(block)
        # Expect 4 ops: two constants, one addition, and an identity on the
        # result.
        self.assertEqual(op_count, 4)
Ejemplo n.º 14
0
def create_dummy_empty_tensorflow_computation():
    """Builds a `pb.Computation` wrapping a TensorFlow graph with no content.

    The type signature of this `pb.Computation` is:

    ( -> <>)

    Returns:
      A `pb.Computation` with no parameter binding whose result was captured
      from an empty list.
    """
    graph = tf.Graph()
    with graph.as_default():
        captured = tensorflow_utils.capture_result_from_graph([], graph)
    result_type, result_binding = captured

    serialized_type = type_serialization.serialize_type(
        computation_types.FunctionType(None, result_type))
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=None, result=result_binding)
    return pb.Computation(type=serialized_type, tensorflow=tensorflow)
Ejemplo n.º 15
0
def create_dummy_computation_tensorflow_tuple():
    """Builds a no-argument TensorFlow computation returning a named struct.

    The computation's type is
    `( -> <('a', float32), ('b', float32), ('c', float32)>)`; every element
    is a constant created from the Python float `10.0`.

    Returns:
      A `(proto, type_signature)` pair: the `pb.Computation` and the matching
      `computation_types.FunctionType`.
    """
    element_value = 10.0

    graph = tf.Graph()
    with graph.as_default():
        # One constant per name, created in name order inside the graph.
        elements = [(name, tf.constant(element_value))
                    for name in ['a', 'b', 'c']]
        result = structure.Struct(elements)
        captured = tensorflow_utils.capture_result_from_graph(result, graph)
        result_type, result_binding = captured

    type_signature = computation_types.FunctionType(None, result_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=None, result=result_binding)
    serialized_type = type_serialization.serialize_type(type_signature)
    proto = pb.Computation(type=serialized_type, tensorflow=tensorflow)
    return proto, type_signature
Ejemplo n.º 16
0
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Builds a TensorFlow computation representing an identity function.

  The computation has the type signature `(T -> T)`, where `T` is
  `type_signature`. NOTE: if `T` contains `computation_types.StructType`s
  without an associated container type, they will be given the container type
  `tuple` by this function.

  Args:
    type_signature: A `computation_types.Type` used as both the parameter type
      and the result type of the identity function. Must not be `None`.

  Returns:
    The value produced by `_tensorflow_comp` for the constructed
    `pb.TensorFlow` proto and its function type (annotated `ProtoAndType`).

  Raises:
    TypeError: If `type_signature` is `None` or contains any types which
      cannot appear in TensorFlow bindings.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if type_signature is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')
  parameter_type = type_signature

  graph = tf.Graph()
  with graph.as_default():
    stamped = tensorflow_utils.stamp_parameter_in_graph(
        'x', parameter_type, graph)
    identity_value, parameter_binding = stamped
    # TF relies on feeds not-identical to fetches in certain circumstances,
    # so tensors (or every leaf of a struct) go through `tf.identity`.
    if parameter_type.is_tensor():
      identity_value = tf.identity(identity_value)
    elif parameter_type.is_struct():
      identity_value = structure.map_structure(tf.identity, identity_value)
    captured = tensorflow_utils.capture_result_from_graph(
        identity_value, graph)
    result_type, result_binding = captured

  function_type = computation_types.FunctionType(parameter_type, result_type)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=packed_graph_def,
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, function_type)
Ejemplo n.º 17
0
def create_broadcast_scalar_to_shape(scalar_type: tf.DType,
                                     shape: tf.TensorShape) -> pb.Computation:
    """Builds a TensorFlow computation that broadcasts a scalar to `shape`.

    The computation has the type signature `(T -> U)`, where `T` is a scalar
    tensor of dtype `scalar_type` and `U` is the type captured from
    `tf.broadcast_to(parameter, shape)`.

    Args:
      scalar_type: A `tf.DType`, the type of the scalar to broadcast.
      shape: A `tf.TensorShape` to broadcast to. Must be fully defined.

    Returns:
      A `pb.Computation` wrapping the broadcasting graph.

    Raises:
      TypeError: If `scalar_type` is not a `tf.DType` or if `shape` is not a
        `tf.TensorShape`.
      ValueError: If `shape` is not fully defined.
    """
    py_typecheck.check_type(scalar_type, tf.DType)
    py_typecheck.check_type(shape, tf.TensorShape)
    shape.assert_is_fully_defined()
    # The parameter is always a scalar (shape `()`) of the requested dtype.
    scalar_parameter_type = computation_types.TensorType(scalar_type, shape=())

    graph = tf.Graph()
    with graph.as_default():
        stamped = tensorflow_utils.stamp_parameter_in_graph(
            'x', scalar_parameter_type, graph)
        scalar_value, parameter_binding = stamped
        broadcasted = tf.broadcast_to(scalar_value, shape)
        captured = tensorflow_utils.capture_result_from_graph(
            broadcasted, graph)
        result_type, result_binding = captured

    function_type = computation_types.FunctionType(scalar_parameter_type,
                                                   result_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def,
        parameter=parameter_binding,
        result=result_binding)
    serialized_type = type_serialization.serialize_type(function_type)
    return pb.Computation(type=serialized_type, tensorflow=tensorflow)
Ejemplo n.º 18
0
  def test_partitioned_call_nodes(self):
    """Checks a proto before and after disabling Grappler on call ops."""

    @tf.function
    def test():
      return tf.constant(1)

    graph = tf.Graph()
    with graph.as_default():
      captured = tensorflow_utils.capture_result_from_graph(test(), graph)
      result_type, result_binding = captured

    serialized_function_type = type_serialization.serialize_type(
        computation_types.FunctionType(None, result_type))
    tensorflow = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=None,
        result=result_binding)
    proto = pb.Computation(type=serialized_function_type, tensorflow=tensorflow)

    self.assertCallOpsGrapplerNotDisabled(proto)
    disable_grappler = (
        tensorflow_computation_transformations
        .disable_grappler_for_partitioned_calls)
    self.assertCallOpsGrapplerDisabled(disable_grappler(proto))
  def test_gets_all_explicit_placement(self):
    """Ops created under an explicit CPU scope share one CPU placement."""
    graph = tf.Graph()
    with graph.as_default():
      with tf.device('/cpu:0'):
        total = tf.constant(0) + tf.constant(1)

    _unused_type, result_binding = (
        tensorflow_utils.capture_result_from_graph(total, graph))

    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    serialized_type = type_serialization.serialize_type(
        computation_types.FunctionType(None, tf.int32))
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=None, result=result_binding)
    proto = pb.Computation(type=serialized_type, tensorflow=tensorflow)

    block = building_blocks.ComputationBuildingBlock.from_proto(proto)
    device_placements = building_block_analysis.get_device_placement_in(block)
    placement_names = list(device_placements.keys())
    # Exactly one placement key is expected; it must mention the CPU and have
    # a positive count.
    self.assertLen(placement_names, 1)
    self.assertIn('CPU', placement_names[0])
    self.assertGreater(device_placements[placement_names[0]], 0)
Ejemplo n.º 20
0
    def test_valid_ops(self):
        """Runs the disallowed-op check on a graph with only a constant."""

        @tf.function
        def test():
            return tf.constant(1)

        graph = tf.Graph()
        with graph.as_default():
            captured = tensorflow_utils.capture_result_from_graph(
                test(), graph)
            result_type, result_binding = captured

        serialized_function_type = type_serialization.serialize_type(
            computation_types.FunctionType(None, result_type))
        tensorflow = computation_pb2.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(
                graph.as_graph_def()),
            parameter=None,
            result=result_binding)
        proto = computation_pb2.Computation(
            type=serialized_function_type, tensorflow=tensorflow)

        # The call is expected to complete without raising for this proto.
        tensorflow_computation_transformations.check_no_disallowed_ops(
            proto, frozenset(['ShardedFilename']))
Ejemplo n.º 21
0
    def test_counts_correct_number_of_ops_with_function(self):
        """Counts ops of a graph that calls a `tf.function` twice."""

        @tf.function
        def add_one(x):
            return x + 1

        graph = tf.Graph()
        with graph.as_default():
            stamped = tensorflow_utils.stamp_parameter_in_graph(
                'x', tf.int32, graph)
            parameter_value, parameter_binding = stamped
            twice_incremented = add_one(add_one(parameter_value))

        captured = tensorflow_utils.capture_result_from_graph(
            twice_incremented, graph)
        result_type, result_binding = captured

        function_type = computation_types.FunctionType(tf.int32, result_type)
        packed_graph_def = serialization_utils.pack_graph_def(
            graph.as_graph_def())
        tensorflow = pb.TensorFlow(
            graph_def=packed_graph_def,
            parameter=parameter_binding,
            result=result_binding)
        serialized_type = type_serialization.serialize_type(function_type)
        proto = pb.Computation(type=serialized_type, tensorflow=tensorflow)

        block = building_blocks.ComputationBuildingBlock.from_proto(proto)
        op_count = building_block_analysis.count_tensorflow_ops_in(block)

        # Expect 7 ops:
        #    Inside the tf.function:
        #      - one constant
        #      - one addition
        #      - one identity on the result
        #    Inside the tff_computation:
        #      - one placeholder (for the argument)
        #      - two partition calls
        #      - one identity on the tff_computation result
        self.assertEqual(op_count, 7)
Ejemplo n.º 22
0
def create_computation_for_py_fn(
        fn: types.FunctionType,
        parameter_type: Optional[computation_types.Type]) -> pb.Computation:
    """Builds a TensorFlow computation that returns the result of `fn`.

    The computation has the type signature `(T -> U)`, where `T` is
    `parameter_type` and `U` is the type captured from the value `fn` returns.
    When `parameter_type` is `None`, `fn` is called with no arguments and the
    computation has no parameter binding.

    Args:
      fn: A Python function.
      parameter_type: A `computation_types.Type`, or `None` for a function
        that takes no argument.

    Returns:
      A `pb.Computation` wrapping the graph traced from `fn`.
    """
    py_typecheck.check_type(fn, types.FunctionType)
    takes_parameter = parameter_type is not None
    if takes_parameter:
        py_typecheck.check_type(parameter_type, computation_types.Type)

    graph = tf.Graph()
    with graph.as_default():
        if takes_parameter:
            stamped = tensorflow_utils.stamp_parameter_in_graph(
                'x', parameter_type, graph)
            parameter_value, parameter_binding = stamped
            fn_result = fn(parameter_value)
        else:
            parameter_binding = None
            fn_result = fn()
        captured = tensorflow_utils.capture_result_from_graph(
            fn_result, graph)
        result_type, result_binding = captured

    function_type = computation_types.FunctionType(parameter_type, result_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=packed_graph_def,
        parameter=parameter_binding,
        result=result_binding)
    serialized_type = type_serialization.serialize_type(function_type)
    return pb.Computation(type=serialized_type, tensorflow=tensorflow)
Ejemplo n.º 23
0
 def test_capture_result_with_np_ndarray(self):
   """Capturing an empty (2, 0) int32 ndarray yields type `int32[2,0]`."""
   with tf.Graph().as_default() as graph:
     # Shape (2, 0) holds no elements, so the uninitialized buffer is never
     # read.
     empty_array = np.ndarray(shape=(2, 0), dtype=np.int32)
     captured = tensorflow_utils.capture_result_from_graph(empty_array, graph)
     type_spec, binding = captured
   self._assert_captured_result_eq_dtype(type_spec, binding, 'int32[2,0]')
Ejemplo n.º 24
0
 def test_capture_result_with_np_bool(self):
   """Capturing a NumPy boolean scalar yields a result of type `bool`."""
   with tf.Graph().as_default() as graph:
     # `np.bool` was merely an alias of the builtin `bool`; it was deprecated
     # in NumPy 1.20 and removed in 1.24, where it raises AttributeError.
     # `np.bool_` is the actual NumPy boolean scalar type and exists in every
     # NumPy release, so the test really exercises a NumPy value.
     type_spec, binding = tensorflow_utils.capture_result_from_graph(
         np.bool_(True), graph)
   self._assert_captured_result_eq_dtype(type_spec, binding, 'bool')
Ejemplo n.º 25
0
 def test_capture_result_with_np_float64(self):
   """Capturing a `np.float64` scalar yields a result of type `float64`."""
   with tf.Graph().as_default() as graph:
     scalar = np.float64(1.0)
     captured = tensorflow_utils.capture_result_from_graph(scalar, graph)
     type_spec, binding = captured
   self._assert_captured_result_eq_dtype(type_spec, binding, 'float64')
Ejemplo n.º 26
0
 def test_capture_result_with_int(self):
   """Capturing a plain Python int yields a result of type `int32`."""
   with tf.Graph().as_default() as graph:
     captured = tensorflow_utils.capture_result_from_graph(1, graph)
     type_spec, binding = captured
   self._assert_captured_result_eq_dtype(type_spec, binding, 'int32')
Ejemplo n.º 27
0
def tf_computation_serializer(parameter_type: Optional[computation_types.Type],
                              context_stack):
    """Serializes a TF computation with a given parameter type.

    This is a generator implementing a two-step protocol: it first yields the
    argument to pass to the function being serialized, receives that
    function's result through `send`, and then yields the serialized
    computation together with its type.

    Args:
      parameter_type: The parameter type specification if the target accepts a
        parameter, or `None` if the target doesn't declare any parameters.
        When not `None`, it must be an instance of `computation_types.Type`.
      context_stack: The context stack to use; must be an instance of
        `context_stack_base.ContextStack`.

    Yields:
      The first yielded value will be a Python object (such as a dataset,
      a placeholder, or a `structure.Struct`) to be passed to the function to
      serialize, or `None` when `parameter_type` is `None`. The result of the
      function should then be passed to the following `send` call.
      The next yielded value will be
      a tuple of (`pb.Computation`, `tff.Type`), where the computation contains
      the instance with the `pb.TensorFlow` variant set, and the type is an
      instance of `tff.Type`, potentially including Python container
      annotations, for use by TensorFlow computation wrappers.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the signature of the target is not compatible with the
        given parameter type. NOTE(review): no such check is visible in this
        function; presumably raised by a helper it calls -- confirm.
    """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if parameter_type is not None:
        py_typecheck.check_type(parameter_type, computation_types.Type)

    with tf.Graph().as_default() as graph:
        # Stamp the parameter (if any) into the fresh graph under the name
        # 'arg'; a parameterless target gets `None` for both value and binding.
        if parameter_type is not None:
            parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                'arg', parameter_type, graph)
        else:
            parameter_value = None
            parameter_binding = None
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph)
        with context_stack.install(context):
            # The caller runs the target function while this generator is
            # suspended at the `yield`; the value it passes to `send` becomes
            # `result`. `all_variables` presumably collects the variables
            # created during that time -- confirm against `variable_utils`.
            with variable_utils.record_variable_creation_scope(
            ) as all_variables:
                result = yield parameter_value
            initializer_ops = []
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                # Names longer than 50 characters are replaced by one that
                # only mentions the number of variables.
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
                initializer_ops.append(
                    tf.compat.v1.initializers.variables(all_variables,
                                                        name=name))
            # Also pick up whatever is registered in the TABLE_INITIALIZERS
            # collection.
            initializer_ops.extend(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
            if initializer_ops:
                # Before running the main new init op, run any initializers for sub-
                # computations from context.init_ops. Variables from import_graph_def
                # will not make it into the global collections, and so will not be
                # initialized without this code path.
                with tf.compat.v1.control_dependencies(context.init_ops):
                    init_op_name = tf.group(*initializer_ops,
                                            name='grouped_initializers').name
            elif context.init_ops:
                # No new initializers here: group only the init ops already
                # recorded on the context.
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                # Nothing to initialize.
                init_op_name = None

        # The result is captured after the computation context has been
        # uninstalled, but while `graph` is still the default graph.
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(parameter_type,
                                                    result_type)

    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=parameter_binding,
                               result=result_binding,
                               initialize_op=init_op_name)
    # Second and final yield: the serialized computation and its type.
    yield pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow), type_signature
Ejemplo n.º 28
0
def create_constant(scalar_value,
                    type_spec: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation returning a constant `scalar_value`.

  The returned computation has the type signature `( -> T)`, where `T` is
  `type_spec`.

  `scalar_value` must be a scalar, and cannot be a float if any of the tensor
  leaves of `type_spec` contain an integer data type. `type_spec` must contain
  only named tuples and tensor types, but these can be arbitrarily nested.

  Args:
    scalar_value: A scalar value to place in all the tensor leaves of
      `type_spec`.
    type_spec: A `computation_types.Type` describing the constant to build;
      must contain only named tuples and tensor types. Every tensor leaf is
      asserted to have a fully defined shape.

  Returns:
    The value produced by `_tensorflow_comp` for the constructed
    `pb.TensorFlow` proto and the function type `( -> type_spec)` (annotated
    `ProtoAndType`).

  Raises:
    TypeError: If the constraints of `type_spec` are violated, if
      `scalar_value` is not a scalar, or if a non-integer `scalar_value` is
      given while `type_spec` contains an integer tensor.
  """
  if not type_analysis.is_generic_op_compatible_type(type_spec):
    raise TypeError(
        'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
        ' only nested tuples and tensors are permitted.'.format(type_spec))
  inferred_scalar_value_type = type_conversions.infer_type(scalar_value)
  if (not inferred_scalar_value_type.is_tensor() or
      inferred_scalar_value_type.shape != tf.TensorShape(())):
    # The message names this function; it previously referred to
    # `create_tensorflow_constant`, which is not the function raising here.
    raise TypeError(
        'Must pass a scalar value to `create_constant`; encountered '
        'a value {}'.format(scalar_value))
  tensor_dtypes_in_type_spec = []

  def _pack_dtypes(type_signature):
    """Appends dtype of `type_signature` to nonlocal variable."""
    if type_signature.is_tensor():
      tensor_dtypes_in_type_spec.append(type_signature.dtype)
    # The type is returned unchanged with `False` as the second element; the
    # traversal is only used to collect dtypes.
    return type_signature, False

  type_transformations.transform_type_postorder(type_spec, _pack_dtypes)

  # A float (or other non-integer) scalar cannot fill integer tensors.
  if (any(x.is_integer for x in tensor_dtypes_in_type_spec) and
      not inferred_scalar_value_type.dtype.is_integer):
    raise TypeError(
        'Only integers can be used as scalar values if our desired constant '
        'type spec contains any integer tensors; passed scalar {} of dtype {} '
        'for type spec {}.'.format(scalar_value,
                                   inferred_scalar_value_type.dtype, type_spec))

  result_type = type_spec

  def _create_result_tensor(type_spec, scalar_value):
    """Packs `scalar_value` into `type_spec` recursively."""
    if type_spec.is_tensor():
      type_spec.shape.assert_is_fully_defined()
      return tf.constant(
          scalar_value, dtype=type_spec.dtype, shape=type_spec.shape)
    # Non-tensor types are structs here (guarded by the compatibility check
    # above); element names are dropped and a plain list is built.
    return [
        _create_result_tensor(type_element, scalar_value)
        for _, type_element in structure.iter_elements(type_spec)
    ]

  with tf.Graph().as_default() as graph:
    result = _create_result_tensor(result_type, scalar_value)
    # The captured type is ignored: the declared `type_spec` is used as the
    # result type of the computation.
    _, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(None, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
Ejemplo n.º 29
0
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ProtoAndType:
  """Creates TF computation upcasting its argument and applying `operator`.

  The two elements of `type_signature` are stamped into a fresh graph as the
  parameters 'x' and 'y'. When their types are not equivalent, the second
  value is broadcast into the structure of the first type before `operator`
  is applied (directly for tensors, leaf-wise for structs).

  Args:
    type_signature: A `computation_types.StructType` with two elements, both
      of the same type or the second able to be upcast to the first, as
      explained in `apply_binary_operator_with_upcast`, and both containing
      only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    The value produced by `_tensorflow_comp` for the constructed
    `pb.TensorFlow` proto and its function type (annotated `ProtoAndType`).

  Raises:
    TypeError: If `type_signature` is not a two-element struct, contains a
      sequence type, or its first element is neither a tensor nor a struct.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  is_two_element_struct = (
      type_signature.is_struct() and len(type_signature) == 2)
  if not is_two_element_struct:
    raise TypeError('To apply a binary operator, we must by definition have an '
                    'argument which is a `StructType` with 2 elements; '
                    'asked to create a binary operator for type: {t}'.format(
                        t=type_signature))
  if type_analysis.contains(type_signature,
                            lambda candidate: candidate.is_sequence()):
    raise TypeError(
        'Applying binary operators in TensorFlow is only '
        'supported on Tensors and StructTypes; you '
        'passed {t} which contains a SequenceType.'.format(t=type_signature))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_struct():
      packed_elements = [
          (elem_name, _pack_into_type(to_pack, elem_type))
          for elem_name, elem_type in structure.iter_elements(type_spec)
      ]
      return structure.Struct(packed_elements)
    if type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)
    # Any other kind of type yields `None`.
    return None

  first_type = type_signature[0]
  second_type = type_signature[1]

  graph = tf.Graph()
  with graph.as_default():
    first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', first_type, graph)
    second_value, operand_2_binding = (
        tensorflow_utils.stamp_parameter_in_graph('y', second_type, graph))
    # Only upcast when the two operand types actually differ.
    if first_type.is_equivalent_to(second_type):
      second_arg = second_value
    else:
      second_arg = _pack_into_type(second_value, first_type)

    if first_type.is_tensor():
      result_value = operator(first_arg, second_arg)
    elif first_type.is_struct():
      result_value = structure.map_structure(operator, first_arg, second_arg)
    else:
      raise TypeError('Encountered unexpected type {t}; can only handle Tensor '
                      'and StructTypes.'.format(t=first_type))

  captured = tensorflow_utils.capture_result_from_graph(result_value, graph)
  result_type, result_binding = captured

  function_type = computation_types.FunctionType(type_signature, result_type)
  # The single struct parameter is bound element-wise to the two stamped
  # operands.
  struct_binding = pb.TensorFlow.StructBinding(
      element=[operand_1_binding, operand_2_binding])
  parameter_binding = pb.TensorFlow.Binding(struct=struct_binding)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=packed_graph_def,
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, function_type)