Esempio n. 1
0
def make_dataset_from_variant_tensor(variant_tensor, type_spec):
    """Constructs a `tf.data.Dataset` from a variant tensor and type spec.

    Args:
      variant_tensor: The variant tensor that represents the dataset.
      type_spec: The type spec of elements of the data set, either an instance
        of `types.Type` or something convertible to it.

    Returns:
      A corresponding instance of `tf.data.Dataset`.

    Raises:
      TypeError: If `variant_tensor` is not a tensor, or if its dtype is not
        `tf.variant`.
      ValueError: Propagated from `type_utils.type_to_tf_structure` when the
        element type cannot be converted to a TF structure (for example
        `None` or a sequence type).
    """
    # `tf.contrib` was removed in TF 2.x; `tf.is_tensor` is the core
    # replacement for `tf.contrib.framework.is_tensor`.
    if not tf.is_tensor(variant_tensor):
        raise TypeError(
            'Expected `variant_tensor` to be a tensor, found {}.'.format(
                py_typecheck.type_string(type(variant_tensor))))
    if variant_tensor.dtype != tf.variant:
        raise TypeError(
            'Expected `variant_tensor` to be of a variant type, found {}.'.
            format(str(variant_tensor.dtype)))
    element_type = computation_types.to_type(type_spec)
    structure = type_utils.type_to_tf_structure(element_type)
    return tf.data.experimental.from_variant(
        variant_tensor, structure=structure)
Esempio n. 2
0
 def test_type_to_tf_structure_without_names(self):
   """Checks that an unnamed tuple type maps to a matching dataset structure.

   The structure from `type_to_tf_structure` is used to build a dataset from a
   variant placeholder. The dataset's output types and shapes must equal those
   reported by `type_to_tf_dtypes_and_shapes` for the same type.
   """
   type_spec = computation_types.to_type((tf.bool, tf.int32))
   dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes(type_spec)
   structure = type_utils.type_to_tf_structure(type_spec)
   with tf.Graph().as_default():
     # Placeholders are graph-mode only. Use the `compat.v1` path, as the
     # output-type/shape queries below already do, so this also runs on TF 2.x.
     variant_placeholder = tf.compat.v1.placeholder(tf.variant, shape=[])
     ds = tf.data.experimental.from_variant(
         variant_placeholder, structure=structure)
     ds_dtypes = tf.compat.v1.data.get_output_types(ds)
     ds_shapes = tf.compat.v1.data.get_output_shapes(ds)
     test.assert_nested_struct_eq(ds_dtypes, dtypes)
     test.assert_nested_struct_eq(ds_shapes, shapes)
Esempio n. 3
0
 def test_type_to_tf_structure_with_names(self):
   """Checks that a nested named tuple type maps to a matching dataset structure.

   The type includes a nested named tuple and a tensor with an explicit shape
   ([20]). The dataset built from the derived structure must report the same
   output types and shapes as `type_to_tf_dtypes_and_shapes`.
   """
   type_spec = computation_types.to_type(
       collections.OrderedDict([
           ('a', tf.bool),
           ('b',
            collections.OrderedDict([
                ('c', tf.float32),
                ('d', (tf.int32, [20])),
            ])),
       ]))
   dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes(type_spec)
   structure = type_utils.type_to_tf_structure(type_spec)
   with tf.Graph().as_default():
     # Placeholders are graph-mode only. Use the `compat.v1` path, as the
     # output-type/shape queries below already do, so this also runs on TF 2.x.
     variant_placeholder = tf.compat.v1.placeholder(tf.variant, shape=[])
     ds = tf.data.experimental.from_variant(
         variant_placeholder, structure=structure)
     ds_dtypes = tf.compat.v1.data.get_output_types(ds)
     ds_shapes = tf.compat.v1.data.get_output_shapes(ds)
     test.assert_nested_struct_eq(ds_dtypes, dtypes)
     test.assert_nested_struct_eq(ds_shapes, shapes)
Esempio n. 4
0
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

    The serialized graph in `comp.tensorflow.graph_def` is imported inside a
    `tf.compat.v1.wrap_function`. The returned Python callable does three
    things:

    - flattens its argument into the inputs the graph expects;
    - runs the wrapped function;
    - repacks the flat outputs into the structure of the result type.

    Args:
      comp: An instance of `pb.Computation`.
      type_spec: An optional `tff.Type` instance or something convertible to
        it. If given, it must be equivalent to the type serialized in `comp`.
        If omitted, the serialized type is used.
      device: An optional device name. Not supported by this implementation;
        it must be `None`.

    Returns:
      A one-argument callable if the type is a `computation_types.FunctionType`
      whose parameter is not `None`, otherwise a zero-argument callable. Either
      executes the computation in eager mode.

    Raises:
      NotImplementedError: If `device` is not `None`.
      TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
        TensorFlow computation, or `type_spec` is not equivalent to the type
        of `comp`.
      RuntimeError: If the returned callable receives an argument whose number
        of flattened parts differs from the number of parameters.
    """
    # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
    # it deals exclusively with eager mode. Incubate here, and potentially move
    # there, once stable.

    # Device placement is not implemented in this version; fail fast instead
    # of silently ignoring the request.
    if device is not None:
        raise NotImplementedError(
            'Unable to embed TF code on a specific device.')

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    # A caller-supplied type must agree with the serialized one; otherwise
    # fall back to the type recorded in `comp`.
    if type_spec is not None:
        if not type_utils.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    str(type_spec), str(comp_type)))
    else:
        type_spec = comp_type
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    # A non-function type is treated as a no-argument computation whose
    # result is the whole type.
    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    # Graph tensor names bound to the parameter (if any) and to the result.
    if param_type is not None:
        input_tensor_names = graph_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = graph_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        # Traced by `wrap_function` below. It imports the serialized graph,
        # feeding `args` in place of the parameter tensors, and returns the
        # result tensors.
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                str(len(input_tensor_names)), str(len(args))))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        # If present, the init op is requested as the last returned element.
        init_names = [init_op] if init_op else []
        # NOTE(review): `uniquify_shared_names` presumably renames shared
        # resource names so repeated imports do not collide; confirm in
        # graph_merge.
        returned_elements = tf.import_graph_def(
            graph_merge.uniquify_shared_names(graph_def),
            input_map=dict(zip(input_tensor_names, args)),
            return_elements=output_tensor_names + init_names)
        if init_names:
            # Route every result through `tf.identity` under a control
            # dependency on the init op, so initialization runs before the
            # results are produced. The init op itself is not returned.
            with tf.control_dependencies([returned_elements[-1]]):
                return [tf.identity(x) for x in returned_elements[0:-1]]
        else:
            return returned_elements

    # Build the `wrap_function` input signature from the flattened parameter
    # type, together with a converter for each incoming part. Tensors pass
    # through unchanged; sequences are passed as scalar variant tensors made
    # by `tf.data.experimental.to_variant`.
    signature = []
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                signature.append(tf.TensorSpec(spec.shape, spec.dtype))
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                signature.append(tf.TensorSpec([], tf.variant))
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    # Converters for the flattened results. Tensors pass through; sequences
    # come back as variant tensors and are turned into datasets using the TF
    # structure of the sequence element type.
    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_utils.type_to_tf_structure(spec.element)

            # `structure` is bound as a default argument so each closure keeps
            # the value from its own loop iteration.
            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        # Flattens `arg`, converts each part, invokes the wrapped graph, then
        # converts the flat results and repacks them as `result_type`.
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    str(len(param_fns)), str(len(arg_parts))))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)
        result_elements = []
        # NOTE(review): `zip` stops at the shorter sequence, so a mismatch
        # between `result_parts` and `result_fns` would be silently truncated.
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)
    # Return a one-argument callable for computations with a parameter and a
    # zero-argument callable otherwise.
    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
Esempio n. 5
0
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

    The serialized graph in `comp.tensorflow.graph_def` is imported inside a
    `tf.compat.v1.wrap_function`. The returned Python callable does three
    things:

    - flattens its argument into the inputs the graph expects;
    - runs the wrapped function;
    - repacks the flat outputs into the structure of the result type.

    Args:
      comp: An instance of `pb.Computation`.
      type_spec: An optional `tff.Type` instance or something convertible to
        it. If given, it must be equivalent to the type serialized in `comp`.
        If omitted, the serialized type is used.
      device: An optional device name. If given, both the graph import and
        every invocation of the returned callable run under
        `tf.device(device)`.

    Returns:
      A one-argument callable if the type is a `computation_types.FunctionType`
      whose parameter is not `None`, otherwise a zero-argument callable. Either
      executes the computation in eager mode.

    Raises:
      TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
        TensorFlow computation, or `type_spec` is not equivalent to the type
        of `comp`.
      RuntimeError: If the returned callable receives an argument whose number
        of flattened parts differs from the number of parameters.
    """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    # A caller-supplied type must agree with the serialized one; otherwise
    # fall back to the type recorded in `comp`.
    if type_spec is not None:
        if not type_utils.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    # A non-function type is treated as a no-argument computation whose
    # result is the whole type.
    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    # Graph tensor names bound to the parameter (if any) and to the result.
    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        # Traced by `wrap_function` below. It imports the serialized graph,
        # feeding `args` in place of the parameter tensors, and returns the
        # result tensors. The import runs under `tf.device(device)` when a
        # device was given.
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                len(input_tensor_names), len(args)))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            # NOTE(review): presumably this makes the graph's ops depend on the
            # init op so initialization happens first; confirm in
            # tensorflow_utils.
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def),
                input_map=dict(list(zip(input_tensor_names, args))),
                return_elements=output_tensor_names)

        if device is not None:
            with tf.device(device):
                return _import_fn()
        else:
            return _import_fn()

    # Build the `wrap_function` input signature from the flattened parameter
    # type, together with a converter for each incoming part. Tensors pass
    # through unchanged; sequences are passed as scalar variant tensors made
    # by `tf.data.experimental.to_variant`.
    signature = []
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                signature.append(tf.TensorSpec(spec.shape, spec.dtype))
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                signature.append(tf.TensorSpec([], tf.variant))
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    # Converters for the flattened results. Tensors pass through; sequences
    # come back as variant tensors and are turned into datasets using the TF
    # structure of the sequence element type.
    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_utils.type_to_tf_structure(spec.element)

            # `structure` is bound as a default argument so each closure keeps
            # the value from its own loop iteration.
            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        # Flattens `arg`, converts each part, invokes the wrapped graph,
        # destroys variable resources it created, then converts the flat
        # results and repacks them as `result_type`.
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue b/144127474 that variables created
        # from tf.import_graph_def(...) inside tf.wrap_function(...) is not
        # destroyed.  So get all the variables from `wrapped_fn` and destroy
        # manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        # NOTE(review): the op scan and `prune` below run on every call even
        # though they depend only on `wrapped_fn`; they could likely be hoisted
        # out of the per-call path. Confirm before changing.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = []
        # NOTE(review): `zip` stops at the shorter sequence, so a mismatch
        # between `result_parts` and `result_fns` would be silently truncated.
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    # With a device, run the whole invocation under that device scope. This
    # includes argument conversion and result conversion.
    if device is not None:
        old_fn_to_return = fn_to_return

        # pylint: disable=function-redefined
        def fn_to_return(x):
            with tf.device(device):
                return old_fn_to_return(x)

        # pylint: enable=function-redefined

    # Return a one-argument callable for computations with a parameter and a
    # zero-argument callable otherwise.
    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
Esempio n. 6
0
 def test_type_to_tf_structure_with_no_elements(self):
     """An empty named tuple type has no TF structure and raises ValueError."""
     with self.assertRaises(ValueError):
         empty_tuple_type = computation_types.NamedTupleType([])
         type_utils.type_to_tf_structure(empty_tuple_type)
Esempio n. 7
0
 def test_type_to_tf_structure_with_inconsistently_named_elements(self):
     """A tuple mixing named and unnamed elements raises ValueError."""
     with self.assertRaises(ValueError):
         mixed_names_type = computation_types.NamedTupleType(
             [('a', tf.int32), tf.bool])
         type_utils.type_to_tf_structure(mixed_names_type)
Esempio n. 8
0
 def test_type_to_tf_structure_with_sequence_type(self):
     """A sequence type cannot be turned into a TF structure (ValueError)."""
     with self.assertRaises(ValueError):
         sequence_type = computation_types.SequenceType(tf.int32)
         type_utils.type_to_tf_structure(sequence_type)
Esempio n. 9
0
 def test_type_to_tf_structure_with_none(self):
     """A `None` type spec is rejected with ValueError."""
     with self.assertRaises(ValueError):
         absent_type = None
         type_utils.type_to_tf_structure(absent_type)