Example No. 1
try:
    import wrapt
except ImportError:
    # Fall back to the build-time dependency if the system package is not available.
    from .....third_party import wrapt  # pylint: disable=relative-beyond-top-level

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import layer_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

module = lazy_loader.LazyLoader("module", globals(),
                                "tensorflow.python.module.module")


class NoDependency(object):
    """Allows attribute assignment to `Trackable` objects with no dependency.

  Example usage:
  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  `obj` in this example has a dependency on the variable "dep", and both
  attributes contain un-wrapped `Variable` objects.
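All of the examples on this page share the same pattern: `lazy_loader.LazyLoader` subclasses `types.ModuleType` and defers the real import until the first attribute access. A minimal self-contained sketch of the idea (simplified, not TensorFlow's exact implementation; the `json` target below is a placeholder):

import importlib
import types


class SimpleLazyLoader(types.ModuleType):
    """Imports `module_name` only when an attribute is first accessed."""

    def __init__(self, local_name, parent_globals, module_name):
        self._local_name = local_name
        self._parent_globals = parent_globals
        super(SimpleLazyLoader, self).__init__(module_name)

    def _load(self):
        module = importlib.import_module(self.__name__)
        # Replace the placeholder in the caller's namespace so later
        # lookups bypass this wrapper entirely.
        self._parent_globals[self._local_name] = module
        return module

    def __getattr__(self, item):
        return getattr(self._load(), item)


# Usage mirrors the snippets above:
# json_mod = SimpleLazyLoader("json_mod", globals(), "json")
# json_mod.dumps({"a": 1})  # The real import happens here.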
Example No. 2
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import object_identity

# Lazy load the single eager module to avoid introducing new dependencies for
# graph_util:convert_variables_to_constants (e.g. in
# tensorflow/contrib/session_bundle:session_bundle_py_test).
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")

# Used in _FunctionConverterDataInGraph().
VAR_ASSIGN_COLLECTION = "extra_var_assign_ops"
_CONDITIONAL_OPS = set(["If", "StatelessIf"])
_LOOP_OPS = set(["While", "StatelessWhile"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)
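A hypothetical helper (not part of the original module) showing how these op sets are typically consulted while folding variables into constants:

def _is_control_flow_node(node):
    # Functional conditionals and loops need special handling when
    # variables are converted to constants.
    return node.op in _CONTROL_FLOW_OPS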


class _TensorData(
    collections.namedtuple("_TensorData", ["numpy", "dtype", "index"])):
  """Data about a tensor that was converted to a constant."""
  __slots__ = ()

  @property
Example No. 3
    "the loop.")

# NOTE(jsimsa): Threshold used as a heuristic to check for infinite loop during
# tf.function tracing.
GET_NEXT_CALL_ERROR_THRESHOLD = 32

GET_NEXT_CALL_ERROR_MESSAGE = (
    "An unusually high number of `tf.data.Iterator.get_next()` calls was "
    "detected. This suggests that the `for elem in dataset: ...` idiom is used "
    "within tf.function with AutoGraph disabled. This idiom is only supported "
    "when AutoGraph is enabled.")

# Collection of all IteratorResources in the `Graph`.
GLOBAL_ITERATORS = "iterators"

autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(), "tensorflow.python.autograph.core.ag_ctx")

# Avoid circular dependency for `type_utils` which transitively depends
# on Autograph which in turn depends on tf.data.
type_utils = lazy_loader.LazyLoader("type_utils", globals(),
                                    "tensorflow.python.framework.type_utils")


def _device_stack_is_empty():
    if context.executing_eagerly():
        return context.context().device_name is None
    # pylint: disable=protected-access
    device_stack = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access
    return not bool(device_stack)
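To make the heuristic concrete, here is a hedged sketch of a counter check consistent with the threshold and message above (the `_next_call_count` attribute is hypothetical):

def _check_get_next_call_count(iterator):
    # Hypothetical guard: count get_next() calls made during a single
    # trace and fail loudly once the threshold is exceeded.
    count = getattr(iterator, "_next_call_count", 0) + 1
    iterator._next_call_count = count
    if count > GET_NEXT_CALL_ERROR_THRESHOLD:
        raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)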
Example No. 4
from __future__ import division
from __future__ import print_function

import numpy as onp
import jax.numpy as np

from tensorflow_probability.python.internal.backend.jax import dtype as dtypes
from tensorflow_probability.python.internal.backend.jax import ops
from tensorflow_probability.python.internal.backend.jax import ops as tensor_shape
from tensorflow_probability.python.internal.backend.jax import numpy_array as array_ops
from tensorflow_probability.python.internal.backend.jax import v1 as check_ops
from tensorflow_probability.python.internal.backend.jax import numpy_math as math_ops

from tensorflow.python.util import lazy_loader
distribution_util = lazy_loader.LazyLoader(
    "distribution_util", globals(),
    "tensorflow_probability.python.experimental.substrates.numpy.internal."
    "distribution_util")

from tensorflow_probability.python.internal.backend.jax import linalg_impl as linalg
from tensorflow_probability.python.internal.backend.jax import linear_operator
from tensorflow_probability.python.internal.backend.jax import linear_operator_util
from tensorflow_probability.python.internal.backend.jax import numpy_signal as fft_ops
# from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorCirculant",
    "LinearOperatorCirculant2D",
    "LinearOperatorCirculant3D",
]

# Different FFT Ops will be used for different block depths.
Example No. 5
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.lib.core import _pywrap_py_func
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

autograph = lazy_loader.LazyLoader("autograph", globals(),
                                   "tensorflow.python.autograph.impl.api")

# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
# used for differentiation.
tape_cache = {}


def _maybe_copy_to_context_device(tensor, device_name):
    """Copy an EagerTensor to the current device if it's not on `device_name`."""
    in_device = tensor.backing_device
    if device_name == in_device:
        return tensor
    else:
        # Note that EagerTensor._copy bypasses the placer and copies to the context
        # device, which means e.g. int32 Tensors which would normally be forced onto
        # the CPU can instead be placed on the GPU. This is necessary so that the
Example No. 6
            # Mod out by the total number of elements to ensure the index is
            # non-negative (for tf.gather) and < 2 * n - 1.
            2 * n - 1)
        return array_ops.gather(elements, indices, axis=-1)

    @property
    def col(self):
        return self._col

    @property
    def row(self):
        return self._row


def _to_complex(x):
    dtype = dtypes.complex64
    if x.dtype in [dtypes.float64, dtypes.complex128]:
        dtype = dtypes.complex128
    return _ops.cast(x, dtype)


import numpy as np
from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg
from tensorflow_probability.python.internal.backend.numpy import ops as _ops
from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape

from tensorflow.python.util import lazy_loader
distribution_util = lazy_loader.LazyLoader(
    "distribution_util", globals(),
    "tensorflow_probability.python.internal._numpy.distribution_util")
Example No. 7
import six

from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function

from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import script_ops
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader

autograph = lazy_loader.LazyLoader("autograph", globals(),
                                   "tensorflow.python.autograph.impl.api")
# TODO(mdan): Create a public API for this.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(), "tensorflow.python.autograph.core.ag_ctx")
dataset_ops = lazy_loader.LazyLoader("dataset_ops", globals(),
                                     "tensorflow.python.data.ops.dataset_ops")


def _should_pack(arg):
    """Determines whether the caller needs to pack the argument in a tuple.

  If a user-defined function returns a list of tensors, `nest.flatten()` and
  `ops.convert_to_tensor()` would conspire to attempt to stack those tensors
  into a single tensor because the tf.data version of `nest.flatten()` does
  not recurse into lists. Since it is more likely that the list arose from
  returning the result of an operation (such as `tf.numpy_function()`) that
Example No. 8
import six

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import _proto_comparators
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

tf_export(v1=["GraphDef"])(graph_pb2.GraphDef)

# A normal import here would generate circular dependencies.
convert_to_constants = lazy_loader.LazyLoader(
    "convert_to_constants", globals(),
    "tensorflow.python.framework.convert_to_constants")

_VARIABLE_OPS = {
    "Assign",
    "AssignAdd",
    "AssignSub",
    "Queue",
    "ScatterAdd",
    "ScatterSub",
    "ScatterUpdate",
    "TruncatedNormal",
    "Variable",
    "VariableV2",
}
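A hypothetical helper (not part of the original module) illustrating how such an op allowlist is typically used when scanning a `GraphDef`:

def _graph_has_variable_ops(graph_def):
    # True if any node uses one of the stateful variable ops listed above.
    return any(node.op in _VARIABLE_OPS for node in graph_def.node)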
Example No. 9
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities for Keras classes with v1 and v2 versions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.util import lazy_loader

# TODO(b/134426265): Switch back to single-quotes once the issue
# with copybara is fixed.
# pylint: disable=g-inconsistent-quotes
training = lazy_loader.LazyLoader("training", globals(),
                                  "tensorflow.python.keras.engine.training")
training_v1 = lazy_loader.LazyLoader(
    "training_v1", globals(), "tensorflow.python.keras.engine.training_v1")
base_layer = lazy_loader.LazyLoader(
    "base_layer", globals(), "tensorflow.python.keras.engine.base_layer")
base_layer_v1 = lazy_loader.LazyLoader(
    "base_layer_v1", globals(), "tensorflow.python.keras.engine.base_layer_v1")
callbacks = lazy_loader.LazyLoader("callbacks", globals(),
                                   "tensorflow.python.keras.callbacks")
callbacks_v1 = lazy_loader.LazyLoader("callbacks_v1", globals(),
                                      "tensorflow.python.keras.callbacks_v1")

# pylint: enable=g-inconsistent-quotes


class ModelVersionSelector(object):
Example No. 10
from tensorflow_probability.python import glm
from tensorflow_probability.python import layers
from tensorflow_probability.python import math
from tensorflow_probability.python import mcmc
from tensorflow_probability.python import monte_carlo
from tensorflow_probability.python import optimizer
from tensorflow_probability.python import random
from tensorflow_probability.python import stats
from tensorflow_probability.python import sts
from tensorflow_probability.python import util
from tensorflow_probability.python import vi

from tensorflow.python.util import lazy_loader  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util.all_util import remove_undocumented  # pylint: disable=g-direct-tensorflow-import

edward2 = lazy_loader.LazyLoader('edward2', globals(),
                                 'tensorflow_probability.python.edward2')

_allowed_symbols = [
    'bijectors',
    'debugging',
    'distributions',
    'edward2',
    'experimental',
    'glm',
    'layers',
    'math',
    'mcmc',
    'monte_carlo',
    'optimizer',
    'random',
    'stats',
Example No. 11
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities for Keras classes with v1 and v2 versions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.util import lazy_loader

# TODO(b/134426265): Switch back to single-quotes once the issue
# with copybara is fixed.
# pylint: disable=g-inconsistent-quotes
training = lazy_loader.LazyLoader("training", globals(),
                                  "tensorflow.python.keras.engine.training")
training_v1 = lazy_loader.LazyLoader(
    "training_v1", globals(), "tensorflow.python.keras.engine.training_v1")

# pylint: enable=g-inconsistent-quotes


# TODO(omalleyt): Extend to Layer class once Layer class is split.
class VersionSelector(object):
    """Chooses between Keras v1 and v2 Model class."""
    def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
        new_cls = swap_class(cls, training.Model, training_v1.Model)
        return object.__new__(new_cls)


def swap_class(cls, v2_cls, v1_cls):
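    """Sketch of a body for the truncated definition above.

    A hedged reconstruction consistent with `VersionSelector.__new__`; the
    `ops.executing_eagerly_outside_functions()` test is an assumption.
    """
    if cls == object:
        return cls
    if cls in (v2_cls, v1_cls):
        return (v2_cls if ops.executing_eagerly_outside_functions()
                else v1_cls)
    # Recursively swap base classes so subclasses resolve correctly.
    cls.__bases__ = tuple(
        swap_class(base, v2_cls, v1_cls) for base in cls.__bases__)
    return cls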
Example No. 12
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Contains the normalization layer classes and their functional aliases.
"""

from tensorflow.python.util import lazy_loader

normalization = lazy_loader.LazyLoader('normalization', globals(),
                                       'keras.legacy_tf_layers.normalization')


# pylint: disable=invalid-name
# lazy load all the attributes until they are accessed for the first time
def __getattr__(name):
    if name in ['BatchNormalization', 'BatchNorm']:
        return normalization.BatchNormalization
    elif name in ['batch_normalization', 'batch_norm']:
        return normalization.batch_normalization
    else:
        raise AttributeError(
            f'module {__name__} doesn\'t have attribute {name}')
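The module-level `__getattr__` above relies on PEP 562 (Python 3.7+); a standalone sketch of the same pattern, with `heavy` and `json` as placeholder names:

def __getattr__(name):  # PEP 562: called only for names not found normally.
    if name == 'heavy':
        import json as heavy  # stand-in for an expensive import
        return heavy
    raise AttributeError(f'module {__name__} has no attribute {name}')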
Example No. 13
    This method is called on all nodes in the Trackable Graph (generated by
    `_trackable_children`). The nodes are traversed in the order defined by
    `_deserialization_dependencies`.

    All usages of _map_resources should be migrated to this method.

    Args:
      object_map: A dictionary that maps original Trackables to the copied
        Trackables. This only needs to be updated if the object is a
        tf.function, or if the copied tensors are necessary for checkpointing
        this object.
      tensor_map: Dictionary mapping original tensors to copied tensors.
      options: A `tf.saved_model.SaveOptions` object.
      **kwargs: Additional kwargs that may be added at a later time.

    Returns:
      Flat list of original tensors that have been copied.
    """
        del kwargs  # Unused.
        self_object_map, self_tensor_map = self._map_resources(options)
        object_map.update(self_object_map)
        tensor_map.update(self_tensor_map)
        return list(self_tensor_map.keys())


# TODO(kathywu): Delete the imports below once dependencies have been migrated.
python_state = lazy_loader.LazyLoader(
    "python_state", globals(), "tensorflow.python.trackable.python_state")
PythonStateSaveable = python_state.PythonStateSaveable
Example No. 14
    "the loop.")

# NOTE(jsimsa): Threshold used as a heuristic to check for infinite loop during
# tf.function tracing.
GET_NEXT_CALL_ERROR_THRESHOLD = 32

GET_NEXT_CALL_ERROR_MESSAGE = (
    "An unusually high number of `tf.data.Iterator.get_next()` calls was "
    "detected. This suggests that the `for elem in dataset: ...` idiom is used "
    "within tf.function with AutoGraph disabled. This idiom is only supported "
    "when AutoGraph is enabled.")

# Collection of all IteratorResources in the `Graph`.
GLOBAL_ITERATORS = "iterators"

autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(), "tensorflow.python.autograph.core.ag_ctx")


def _device_stack_is_empty():
    if context.executing_eagerly():
        return context.context().device_name is None
    # pylint: disable=protected-access
    device_stack = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access
    return not bool(device_stack)


@tf_export(v1=["data.Iterator"])
class Iterator(trackable.Trackable):
    """Represents the state of iterating through a `Dataset`."""
    def __init__(self, iterator_resource, initializer, output_types,
Example No. 15
from tensorflow.python.training import saver as v1_saver_lib
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import graph_view as graph_view_lib
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency.
keras_backend = lazy_loader.LazyLoader("keras_backend", globals(),
                                       "tensorflow.python.keras.backend")


def get_session():
    # Prefer TF's default session since get_session from Keras has side-effects.
    session = ops.get_default_session()
    if session is None:
        session = keras_backend.get_session()
    return session


class Checkpoint(tf.train.Checkpoint):
    def save(self, file_prefix):
        """Saves a training checkpoint and provides basic checkpoint management.
        The saved checkpoint includes variables created by this object and any
        trackable objects it depends on at the time `Checkpoint.save()` is
Example No. 16
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

np_arrays = lazy_loader.LazyLoader(
    "np_arrays", globals(), "tensorflow.python.ops.numpy_ops.np_arrays")


@tf_export(v1=["map_fn"])
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None,
           fn_output_signature=None):
    """Transforms `elems` by applying `fn` to each element unstacked on axis 0.
Example No. 17
from tensorflow_probability.python import experimental
from tensorflow_probability.python import glm
from tensorflow_probability.python import layers
from tensorflow_probability.python import math
from tensorflow_probability.python import mcmc
from tensorflow_probability.python import monte_carlo
from tensorflow_probability.python import optimizer
from tensorflow_probability.python import stats
from tensorflow_probability.python import sts
from tensorflow_probability.python import util
from tensorflow_probability.python import vi

from tensorflow.python.util import lazy_loader  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util.all_util import remove_undocumented  # pylint: disable=g-direct-tensorflow-import

edward2 = lazy_loader.LazyLoader('edward2', globals(),
                                 'tensorflow_probability.python.edward2')
positive_semidefinite_kernels = lazy_loader.LazyLoader(
    'positive_semidefinite_kernels', globals(),
    'tensorflow_probability.python.positive_semidefinite_kernels')

_allowed_symbols = [
    'bijectors',
    'debugging',
    'distributions',
    'edward2',
    'experimental',
    'glm',
    'layers',
    'math',
    'mcmc',
    'monte_carlo',
Example No. 18
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest


# TODO(b/145618471): Remove this dependency.
# Lazy import to work around circular dependencies
input_lib = lazy_loader.LazyLoader(
    'input_lib', globals(),
    'tensorflow.python.distribute.input_lib')
parallel_ops = lazy_loader.LazyLoader(
    'parallel_ops', globals(),
    'tensorflow.python.ops.parallel_for.control_flow_ops')


UNSPECIFIED = object()


def overload_of(f):
  if f in SUPPORTED_BUILTINS:
    return BUILTIN_FUNCTIONS_MAP[f.__name__]
  return f

Example No. 19
# limitations under the License.
# ==============================================================================
"""Registry for tensor conversion functions."""
# pylint: disable=g-bad-name
import collections
import threading

import numpy as np
import six

from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency
# ops->tensor_conversion_registry->constant_op->ops.
constant_op = lazy_loader.LazyLoader(
    "constant_op", globals(), "tensorflow.python.framework.constant_op")

_tensor_conversion_func_registry = collections.defaultdict(list)
_tensor_conversion_func_cache = {}
_tensor_conversion_func_lock = threading.Lock()

# Instances of these types are always converted using
# `_default_conversion_function`.
_UNCONVERTIBLE_TYPES = six.integer_types + (
    float,
    np.generic,
    np.ndarray,
)
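For context, a hedged sketch of the public API built on this registry; `MyWrapper` and its converter are placeholders, while `tf.register_tensor_conversion_function` is the real entry point:

import tensorflow as tf


class MyWrapper(object):

    def __init__(self, value):
        self.value = value


def _wrapper_to_tensor(value, dtype=None, name=None, as_ref=False):
    del as_ref  # Reference semantics are not supported for this type.
    return tf.convert_to_tensor(value.value, dtype=dtype, name=name)


tf.register_tensor_conversion_function(MyWrapper, _wrapper_to_tensor)
t = tf.convert_to_tensor(MyWrapper([1, 2, 3]))  # Uses the registered function.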


def _default_conversion_function(value, dtype, name, as_ref):
Example No. 20
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest

# TODO(b/145618471): Remove this dependency.
# Lazy import to work around circular dependencies
input_lib = lazy_loader.LazyLoader('input_lib', globals(),
                                   'tensorflow.python.distribute.input_lib')

LIMIT_PYTHON_ITERATIONS = True
PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
WARN_INEFFICIENT_UNROLL = True
INEFFICIENT_UNROLL_MIN_ITERATIONS = 3000
INEFFICIENT_UNROLL_MIN_OPS = 1


def _disallow_undefs_into_loop(*values):
    """Ensures that all values in the state are defined when entering a loop."""
    undefined = tuple(filter(special_values.is_undefined, values))
    if undefined:
        raise ValueError('{} must be defined before the loop.'.format(','.join(
            s.symbol_name for s in undefined)))
    for value in values:
Example No. 21
"""Asset-type Trackable object."""
import os

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.trackable import base
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

# TODO(b/205183809): Remove once nested_structure_coder no longer adds
# dependency cycles.
saved_model_utils = lazy_loader.LazyLoader(
    "saved_model_utils", globals(), "tensorflow.python.saved_model.utils_impl")


@tf_export("saved_model.Asset")
class Asset(base.Trackable):
    """Represents a file asset to hermetically include in a SavedModel.

  A SavedModel can include arbitrary files, called assets, that are needed
  for its use. For example a vocabulary file used to initialize a lookup table.

  When a trackable object is exported via `tf.saved_model.save()`, all the
  `Asset`s reachable from it are copied into the SavedModel assets directory.
  Upon loading, the assets and the serialized functions that depend on them
  will refer to the correct filepaths inside the SavedModel directory.

  Example:
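The docstring example is cut off above; a hedged sketch of typical `tf.saved_model.Asset` usage (paths and names are placeholders):

import tensorflow as tf


class VocabModule(tf.Module):

    def __init__(self, vocab_path):
        # Tracked as an asset; the file is copied into the SavedModel's
        # assets directory at save time.
        self.vocab = tf.saved_model.Asset(vocab_path)


m = VocabModule('/tmp/vocab.txt')           # placeholder path
tf.saved_model.save(m, '/tmp/saved_model')  # placeholder path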
Example No. 22
# limitations under the License.
# ==============================================================================
"""Python API for save and loading a dataset."""

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "NONE"
DATASET_SPEC_FILENAME = "dataset_spec.pb"
# TODO(b/176933539): Use the regular import.
# TODO(b/238903802): Use TypeSpec serialization methods directly.
nested_structure_coder = lazy_loader.LazyLoader(
    "nested_structure_coder", globals(),
    "tensorflow.python.saved_model.nested_structure_coder")


@tf_export("data.experimental.save", v1=[])
@deprecation.deprecated(None, "Use `tf.data.Dataset.save(...)` instead.")
def save(dataset,
         path,
         compression=None,
         shard_func=None,
         checkpoint_args=None):
    """Saves the content of the given dataset.

  Example usage:

  >>> import tempfile
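The doctest is truncated; a hedged usage sketch of the API (paths are placeholders; `tf.data.experimental.load` restores what `save` wrote):

import tensorflow as tf

ds = tf.data.Dataset.range(5)
tf.data.experimental.save(ds, '/tmp/ds')       # placeholder path
restored = tf.data.experimental.load('/tmp/ds')
print(list(restored.as_numpy_iterator()))      # [0, 1, 2, 3, 4]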
Example No. 23
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorFlow Probability alternative substrates."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util import lazy_loader  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util.all_util import remove_undocumented  # pylint: disable=g-direct-tensorflow-import

jax = lazy_loader.LazyLoader(
    'jax', globals(),
    'tensorflow_probability.python.experimental.substrates.jax')
numpy = lazy_loader.LazyLoader(
    'numpy', globals(),
    'tensorflow_probability.python.experimental.substrates.numpy')

_allowed_symbols = [
    'jax',
    'numpy',
]

remove_undocumented(__name__, _allowed_symbols)
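Finally, `remove_undocumented` prunes anything not in the allowlist; a simplified sketch of the idea (not the actual TensorFlow implementation):

import sys


def _remove_undocumented(module_name, allowed_symbols):
    module = sys.modules[module_name]
    module.__all__ = list(allowed_symbols)
    for name in list(vars(module)):
        if not name.startswith('_') and name not in allowed_symbols:
            delattr(module, name)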