# registrations) -- which must be loaded when deserializing tensorflow
# saved models.
_maybe_nonlazy_load = [
    'experimental',
    'layers',
]


def _tf_loaded():
  """Reports whether TensorFlow already appears to be imported.

  Returns:
    `True` when `sys.modules` holds a `tensorflow` entry whose `dir()` lists
    a `compat` attribute. `False` otherwise, including when there is no such
    entry or the entry is `None`.
  """
  tf_module = sys.modules.get('tensorflow')
  if tf_module is None:
    return False
  return 'compat' in dir(tf_module)


# Every subpackage starts out behind a lazy loader. The `if` further down may
# then swap some of them for an eager load.
for pkg_name in _lazy_load + _maybe_nonlazy_load:
  globals()[pkg_name] = lazy_loader.LazyLoader(
      pkg_name,
      globals(),
      'tensorflow_probability.python.{}'.format(pkg_name),
      # `_validate_tf_environment` (defined outside this view) is installed as
      # the first-access hook. That way it runs before the subpackage is really
      # imported, which matters because the subpackages import tensorflow
      # themselves.
      on_first_access=functools.partial(_validate_tf_environment, pkg_name))

if _tf_loaded():
  # TensorFlow is already imported, so eagerly load the subpackages whose
  # import registers things with tensorflow/keras.
  for pkg_name in _maybe_nonlazy_load:
    # Calling `dir` on the lazy loader triggers the underlying import.
    dir(globals()[pkg_name])

all_util.remove_undocumented(__name__, _lazy_load + _maybe_nonlazy_load)
required=required_tensorflow_version, present=tf.__version__))

# Subpackage names exposed from this package. Each one is bound below to a
# `lazy_loader.LazyLoader` for `tensorflow_probability.python.<name>`, rather
# than being imported directly.
_allowed_symbols = [
    'bijectors',
    'debugging',
    'distributions',
    'edward2',
    'experimental',
    'glm',
    'layers',
    'math',
    'mcmc',
    'monte_carlo',
    'optimizer',
    'random',
    'stats',
    'sts',
    'util',
    'vi',
]

for pkg in _allowed_symbols:
  # `_ensure_tf_install` (defined outside this view) is passed as the
  # first-access hook. Presumably it performs the TensorFlow version check
  # whose error-message tail is visible at the top of this span -- confirm.
  globals()[pkg] = lazy_loader.LazyLoader(
      pkg, globals(), 'tensorflow_probability.python.{}'.format(pkg),
      on_first_access=_ensure_tf_install)

# NOTE(review): presumably hides public module attributes not listed in
# `_allowed_symbols`; see `all_util.remove_undocumented` to confirm.
all_util.remove_undocumented(__name__, _allowed_symbols)
# Copyright 2019 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorFlow Probability alternative substrates."""

from tensorflow_probability.python.internal import all_util
from tensorflow_probability.python.internal import lazy_loader  # pylint: disable=g-direct-tensorflow-import


# Substrate subpackages exposed from this package.
_allowed_symbols = [
    'jax',
    'numpy',
]

# Bind each substrate name to a lazy loader for
# `tensorflow_probability.substrates.<name>`, so it is not imported up front.
for _substrate_name in _allowed_symbols:
  globals()[_substrate_name] = lazy_loader.LazyLoader(
      _substrate_name, globals(),
      'tensorflow_probability.substrates.{}'.format(_substrate_name))
# Drop the loop variable so the module namespace matches the explicit form.
del _substrate_name

all_util.remove_undocumented(__name__, _allowed_symbols)