Ejemplo n.º 1
0
      validate_shape=validate_shape, caching_device=caching_device, name=name,
      variable_def=None, dtype=dtype, import_scope=None, constraint=None)


def _placeholder_with_default(input, shape, name=None):  # pylint: disable=redefined-builtin,unused-argument
  x = np.array(input)
  if shape is None or any(s is None for s in shape):
    return x
  return np.reshape(x, shape)


# --- Begin Public Functions --------------------------------------------------


# Public aliases: each pairs a `tf1` assertion with its local implementation
# via `utils.copy_docstring`.
assert_equal = utils.copy_docstring(tf1.assert_equal, _assert_equal)

assert_greater = utils.copy_docstring(tf1.assert_greater, _assert_greater)

assert_less = utils.copy_docstring(tf1.assert_less, _assert_less)

assert_rank = utils.copy_docstring(tf1.assert_rank, _assert_rank)

assert_scalar = utils.copy_docstring(
Ejemplo n.º 2
0
    nbatch = int(np.prod(matrix.shape[:-2]))
    flat_mat = matrix.reshape(nbatch, dim, dim)
    flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
    result = np.empty(flat_rhs.shape)
    if np.size(result):
        # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
        for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
            result[i] = scipy_linalg.solve_triangular(
                mat, rh, lower=lower, trans='C' if adjoint else 'N')
    return result.reshape(*rhs.shape)


# --- Begin Public Functions --------------------------------------------------

# Public aliases: each pairs a `tf.linalg` symbol with its local implementation
# via `utils.copy_docstring`.
adjoint = utils.copy_docstring(
    tf.linalg.adjoint,
    lambda matrix, name=None: _matrix_transpose(  # pylint: disable=g-long-lambda
        matrix, conjugate=True))

band_part = utils.copy_docstring(
    tf.linalg.band_part,
    _band_part)

cholesky = utils.copy_docstring(
    tf.linalg.cholesky,
    lambda input, name=None: np.linalg.cholesky(input))

cholesky_solve = utils.copy_docstring(
    tf.linalg.cholesky_solve,
    _cholesky_solve)

det = utils.copy_docstring(
    tf.linalg.det,
    lambda input, name=None: np.linalg.det(input))

diag = utils.copy_docstring(
    tf.linalg.diag,
    _diag)
Ejemplo n.º 3
0
from __future__ import division
from __future__ import print_function

# Dependency imports
import numpy as onp
import jax.numpy as np

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal.backend.jax import _utils as utils

# Names exported by `from ... import *`.
__all__ = ['difference']


def _difference(a, b, aminusb=True, validate_indices=True):
    if not aminusb:
        raise NotImplementedError(
            'Argument `aminusb != True` is currently unimplemented.')
    if not validate_indices:
        raise NotImplementedError(
            'Argument `validate_indices != True` is currently unimplemented.')
    return np.setdiff1d(a, b)


# --- Begin Public Functions --------------------------------------------------

# TODO(b/136555907): Add unit test.
difference = utils.copy_docstring(
    tf.sets.difference,
    _difference)
Ejemplo n.º 4
0
import jax.numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal.backend.jax import _utils as utils

# Forward transforms first, then their inverses.
__all__ = [
    'fft', 'fft2d', 'fft3d',
    'ifft', 'ifft2d', 'ifft3d',
]


# Each transform is `np.fft.fftn` / `np.fft.ifftn` restricted to the trailing
# one, two or three axes.
fft = utils.copy_docstring(
    tf.signal.fft,
    lambda input, name=None: np.fft.fftn(  # pylint: disable=g-long-lambda
        input, axes=[-1]))

fft2d = utils.copy_docstring(
    tf.signal.fft2d,
    lambda input, name=None: np.fft.fftn(  # pylint: disable=g-long-lambda
        input, axes=[-2, -1]))

fft3d = utils.copy_docstring(
    tf.signal.fft3d,
    lambda input, name=None: np.fft.fftn(  # pylint: disable=g-long-lambda
        input, axes=[-3, -2, -1]))

ifft = utils.copy_docstring(
    tf.signal.ifft,
    lambda input, name=None: np.fft.ifftn(  # pylint: disable=g-long-lambda
        input, axes=[-1]))

ifft2d = utils.copy_docstring(
Ejemplo n.º 5
0
def _argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
    """Numpy implementation of `tf.argsort`."""
    if direction == 'ASCENDING':
        pass
    elif direction == 'DESCENDING':
        values = np.negative(values)
    else:
        raise ValueError('Unrecognized direction: {}.'.format(direction))
    return np.argsort(values, axis, kind='stable' if stable else 'quicksort')


def _sort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
    """Numpy implementation of `tf.sort`."""
    if direction == 'ASCENDING':
        pass
    elif direction == 'DESCENDING':
        values = np.negative(values)
    else:
        raise ValueError('Unrecognized direction: {}.'.format(direction))
    result = np.sort(values, axis, kind='stable' if stable else 'quicksort')
    if direction == 'DESCENDING':
        return np.negative(result)
    return result


# --- Begin Public Functions --------------------------------------------------

# Public sorting entry points, documented with TF's docstrings.
argsort = utils.copy_docstring(
    tf.argsort,
    _argsort)

sort = utils.copy_docstring(
    tf.sort,
    _sort)
Ejemplo n.º 6
0
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Experimental Numpy backend."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal.backend.jax import _utils as utils

# Names exported by `from ... import *`.
__all__ = ['Assert', 'check_numerics']

# The `Assert` implementation ignores its arguments and returns None; the
# `check_numerics` implementation returns its first argument unchanged.
Assert = utils.copy_docstring(  # pylint: disable=invalid-name
    tf.debugging.Assert,
    lambda condition, data, summarize=None, name=None: None)

check_numerics = utils.copy_docstring(
    tf.debugging.check_numerics,
    lambda x, *_, **__: x)
Ejemplo n.º 7
0
                 seed=None,
                 name=None):  # pylint: disable=unused-argument
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    dtype = utils.common_dtype([minval, maxval], dtype_hint=dtype)
    maxval = 1 if maxval is None else maxval
    shape = _shape([], shape)
    return jaxrand.uniform(key=seed,
                           shape=shape,
                           dtype=dtype,
                           minval=minval,
                           maxval=maxval)


# --- Begin Public Functions --------------------------------------------------

# When `JAX_MODE` is falsy the plain implementations are used, otherwise the
# `_*_jax` variants. `poisson` has a single implementation.
categorical = utils.copy_docstring(
    tf.random.categorical,
    _categorical if not JAX_MODE else _categorical_jax)

gamma = utils.copy_docstring(
    tf.random.gamma,
    _gamma if not JAX_MODE else _gamma_jax)

normal = utils.copy_docstring(
    tf.random.normal,
    _normal if not JAX_MODE else _normal_jax)

poisson = utils.copy_docstring(
    tf.random.poisson,
    _poisson)

uniform = utils.copy_docstring(
    tf.random.uniform,
    _uniform if not JAX_MODE else _uniform_jax)
Ejemplo n.º 8
0
    # 'case',
    # 'cond',
    # 'dynamic_partition',
    # 'dynamic_stitch',
    # 'map_fn',
    # 'scan',
]


def _while_loop(
        cond,
        body,
        loop_vars,
        shape_invariants=None,
        parallel_iterations=10,  # pylint: disable=unused-argument
        back_prop=True,
        swap_memory=False,  # pylint: disable=unused-argument
        maximum_iterations=None,
        name=None):  # pylint: disable=unused-argument
    i = 0
    while (cond(*loop_vars)
           and (maximum_iterations is None or i < maximum_iterations)):
        loop_vars = body(*loop_vars)
        i += 1
    return loop_vars


# --- Begin Public Functions --------------------------------------------------

while_loop = utils.copy_docstring(
    tf.while_loop,
    _while_loop)
Ejemplo n.º 9
0
    'uint8',
    # 'as_string',
    # 'bfloat16',
    # 'dtypes',
    # 'qint16',
    # 'qint32',
    # 'qint8',
    # 'quint16',
    # 'quint8',
]


# --- Begin Public Functions --------------------------------------------------

# Resolve `type_value` (anything with a `.name`, or anything `onp.dtype`
# accepts) to the corresponding NumPy scalar type.
as_dtype = utils.copy_docstring(
    tf.as_dtype,
    lambda type_value: onp.dtype(  # pylint: disable=g-long-lambda
        getattr(type_value, 'name', type_value)).type)

# dtype of `np.real` applied to an empty array of `dtype`.
real_dtype = lambda dtype: (  # pylint: disable=g-long-lambda
    np.real(np.zeros((0,), dtype=as_dtype(dtype))).dtype)

# NumPy scalar types exposed under their TF dtype names.
#
# `bool` is NumPy's boolean scalar type `onp.bool_`, matching the other entries
# (all NumPy scalar types) and `onp.dtype('bool').type`. The bare `onp.bool`
# alias of the builtin was deprecated in NumPy 1.20 and removed in 1.24, so
# referencing it raises AttributeError at import time on those releases.
bool = onp.bool_  # pylint: disable=redefined-builtin

complex128 = np.complex128

complex64 = np.complex64

double = np.double

float16 = np.float16

float32 = np.float32
Ejemplo n.º 10
0
def _vlog(*_, **__):  # pylint: disable=unused-argument
  pass


def _warn(*_, **__):  # pylint: disable=unused-argument
  pass


def _warning(*_, **__):  # pylint: disable=unused-argument
  pass


# --- Begin Public Functions --------------------------------------------------

# Public logging aliases, each pairing a `tf1.logging` symbol with its local
# implementation via `utils.copy_docstring`.
TaskLevelStatusMessage = utils.copy_docstring(  # pylint: disable=invalid-name
    tf1.logging.TaskLevelStatusMessage, _TaskLevelStatusMessage)

debug = utils.copy_docstring(tf1.logging.debug, _debug)

error = utils.copy_docstring(tf1.logging.error, _error)

fatal = utils.copy_docstring(tf1.logging.fatal, _fatal)

flush = utils.copy_docstring(
Ejemplo n.º 11
0
        # We offer a non SP version just in case SP isn't installed and this
        # because logsumexp is often used.
        m = _max_mask_non_finite(input_tensor, axis=axis, keepdims=True)
        y = input_tensor - m
        y = np.exp(y, out=y)
        return m + np.log(np.sum(y, axis=_astuple(axis), keepdims=keepdims))


def _top_k(input, k=1, sorted=True, name=None):  # pylint: disable=unused-argument,redefined-builtin
    raise NotImplementedError


# --- Begin Public Functions --------------------------------------------------

abs = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.math.abs, lambda x, name=None: np.abs(x))

# Element-wise sum of `inputs`, optionally cast to `tensor_dtype`.
# `tensor_dtype=None` keeps the natural dtype of the sum. It must not reach
# `.astype`: `utils.numpy_dtype(None)` presumably passes `None` through, and
# NumPy interprets `astype(None)` as a cast to float64, which silently changes
# the dtype (e.g. int32 or float32 inputs become float64).
accumulate_n = utils.copy_docstring(
    tf.math.accumulate_n,
    lambda inputs, shape=None, tensor_dtype=None, name=None: (  # pylint: disable=g-long-lambda
        sum(map(np.array, inputs)) if tensor_dtype is None
        else sum(map(np.array, inputs)).astype(
            utils.numpy_dtype(tensor_dtype))))

acos = utils.copy_docstring(
    tf.math.acos,
    lambda x, name=None: np.arccos(x))

acosh = utils.copy_docstring(
    tf.math.acosh,
    lambda x, name=None: np.arccosh(x))

add = utils.copy_docstring(
    tf.math.add,
    lambda x, y, name=None: np.add(x, y))

# Python `sum` over the inputs, each converted with `np.array`.
add_n = utils.copy_docstring(
    tf.math.add_n,
    lambda inputs, name=None: sum(np.array(t) for t in inputs))
Ejemplo n.º 12
0
# ============================================================================
"""Numpy implementations of sparse functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal.backend.jax import _utils as utils

# Names exported by `from ... import *`.
__all__ = ['to_dense']


def _to_dense(sp_input, default_value=0, validate_indices=True, name=None):  # pylint: disable=unused-argument
    if default_value != 0:
        raise NotImplementedError(
            'Argument `default_value != 0` is currently unimplemented.')
    if not validate_indices:
        raise NotImplementedError(
            'Argument `validate_indices != True` is currently unimplemented.')
    return sp_input


# --- Begin Public Functions --------------------------------------------------

# TODO(b/136555907): Add unit test.
to_dense = utils.copy_docstring(
    tf.sparse.to_dense,
    _to_dense)
Ejemplo n.º 13
0
def _function(
        func=None,
        input_signature=None,
        autograph=True,  # pylint: disable=unused-argument
        experimental_autograph_options=None,  # pylint: disable=unused-argument
        experimental_relax_shapes=False,
        experimental_compile=None):  # pylint: disable=unused-argument
    """Dummy version of `tf.function`."""
    # This code path is for the `foo = tf.function(foo, ...)` use case.
    if func is not None:
        return func
    # This code path is for the following use case:
    #   @tf.function(...)
    #   def foo(...):
    #      ...
    # This case is equivalent to `foo = tf.function(...)(foo)`.
    return lambda inner_function: inner_function


# --- Begin Public Functions --------------------------------------------------

# Minimal `tf.compat` namespace exposing only `dimension_value`.
compat = collections.namedtuple('compat', 'dimension_value')(
    lambda dim: int(dim) if dim is not None else None)

function = utils.copy_docstring(
    tf.function,
    _function)

# Top-level re-exports of linear-algebra ops.
eye = linalg.eye
matmul = linalg.matmul

# Remove helper modules from the module namespace.
del collections, tf, utils
Ejemplo n.º 14
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Experimental Numpy backend."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import numpy as onp
import jax.numpy as np

import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal.backend.jax import _utils as utils

# Names exported by `from ... import *`.
__all__ = ['constant']

def _constant(value=0, dtype=tf.dtypes.float32, verify_shape=False):  # pylint: disable=unused-argument
    """Stand-in for `tf1.initializers.constant`.

    Args:
      value: Fill value; combined with the requested shape by multiplying it
        with `np.ones(shape)`, so broadcastable array values also work.
      dtype: dtype used when the returned initializer is called without its
        own `dtype`.
      verify_shape: Accepted for API compatibility; ignored.

    Returns:
      An initializer `fn(shape, dtype=None, partition_info=None,
      verify_shape=None)` returning `np.ones(shape, dtype) * value`.
    """
    default_dtype = dtype

    def _initializer(shape, dtype=None, partition_info=None, verify_shape=None):  # pylint: disable=unused-argument
        # The inner `dtype` shadows the constructor's. When the caller omits
        # it, fall back to the dtype given at construction (as TF1
        # initializers do) instead of letting `np.ones` pick its default float
        # dtype. Compare with `None`, since `np.dtype` instances are falsy.
        if dtype is None:
            dtype = default_dtype
        return np.ones(shape, dtype=utils.numpy_dtype(dtype)) * value

    return _initializer


constant = utils.copy_docstring(tf1.initializers.constant, _constant)
Ejemplo n.º 15
0
            if np.issubdtype(value.dtype, np.integer):
                if not np.issubdtype(dtype_hint, np.integer):
                    return value
        return value.astype(dtype_hint)
    return np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))


def _dimension_value(dimension):
    if dimension is None:
        return None
    return int(dimension)


# --- Begin Public Functions --------------------------------------------------

dimension_value = utils.copy_docstring(
    tf.compat.dimension_value,
    _dimension_value)


class GradientTape(object):
    """tf.GradientTape stub."""
    def __init__(self, persistent=False, watch_accessed_variables=True):  # pylint: disable=unused-argument
        pass

    def __enter__(self):
        return self

    def __exit__(self, typ, value, traceback):  # pylint: disable=unused-argument
        pass

    def watch(self, tensor):  # pylint: disable=unused-argument
        pass
Ejemplo n.º 16
0
def _scan(  # pylint: disable=unused-argument
        fn,
        elems,
        initializer=None,
        parallel_iterations=10,
        back_prop=True,
        swap_memory=False,
        infer_shape=True,
        reverse=False,
        name=None):
    """Scan implementation."""
    out = []
    if initializer is None:
        arg = elems[0]
        elems = elems[1:]
    else:
        arg = initializer

    for x in elems:
        arg = fn(arg, x)
        out.append(arg)
    return np.array(out)


# --- Begin Public Functions --------------------------------------------------

map_fn = utils.copy_docstring(
    tf.map_fn,
    _map_fn)

scan = utils.copy_docstring(
    tf.scan,
    _scan)
Ejemplo n.º 17
0
def _transpose(a, perm=None, conjugate=False, name='transpose'):  # pylint: disable=unused-argument
    x = np.transpose(a, perm)
    return np.conjugate(x) if conjugate else x


def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
    """Numpy implementation of `tf.zeros_like`.

    Args:
      input: Value whose shape (via `_shape`) and, by default, dtype are used.
      dtype: Optional output dtype; `None` means "use `input.dtype`".
      name: Forwarded to `tf.zeros` on the non-array-shape path.

    Returns:
      Zeros whose shape is `_shape(input)`.
    """
    s = _shape(input)
    # Compare with `None` explicitly. `np.dtype` instances define `__len__`
    # (number of record fields, 0 for plain dtypes) and are therefore falsy,
    # so `dtype or ...` would silently discard e.g. `dtype=np.dtype('int32')`.
    if isinstance(s, (np.ndarray, onp.generic)):
        out_dtype = input.dtype if dtype is None else dtype
        return np.zeros(s, utils.numpy_dtype(out_dtype))
    return tf.zeros(s, s.dtype if dtype is None else dtype, name)


# --- Begin Public Functions --------------------------------------------------

# Concatenate after converting every input with `ops.convert_to_tensor`.
concat = utils.copy_docstring(
    tf.concat,
    lambda values, axis, name='concat': np.concatenate(  # pylint: disable=g-long-lambda
        [ops.convert_to_tensor(value) for value in values], axis))

expand_dims = utils.copy_docstring(
    tf.expand_dims,
    lambda input, axis, name=None: np.expand_dims(input, axis))

# `value` times ones of shape `dims`, using the dtype `np.array(value)` infers.
fill = utils.copy_docstring(
    tf.fill,
    lambda dims, value, name=None: (  # pylint: disable=g-long-lambda
        value * np.ones(dims, np.array(value).dtype)))

gather = utils.copy_docstring(
    tf.gather,
    _gather)

gather_nd = utils.copy_docstring(
    tf.gather_nd,
    _gather_nd)

reverse = utils.copy_docstring(
    tf.reverse,
    _reverse)