def _vlog(*_, **__):  # pylint: disable=unused-argument
  pass


def _warn(*_, **__):  # pylint: disable=unused-argument
  pass


def _warning(*_, **__):  # pylint: disable=unused-argument
  pass


# --- Begin Public Functions --------------------------------------------------

# Each public name is `utils.copy_docstring(<tf1.logging symbol>, <impl>)`.
TaskLevelStatusMessage = utils.copy_docstring(  # pylint: disable=invalid-name
    tf1.logging.TaskLevelStatusMessage, _TaskLevelStatusMessage)

debug = utils.copy_docstring(tf1.logging.debug, _debug)

error = utils.copy_docstring(tf1.logging.error, _error)

fatal = utils.copy_docstring(tf1.logging.fatal, _fatal)

flush = utils.copy_docstring(
Esempio n. 2
0
  del name
  value = np.array(value)
  return list(
      np.squeeze(x, axis=axis)
      for x in np.split(value, value.shape[axis] if num is None else num, axis))


def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin,unused-argument
  """Return zeros shaped like `input`, with dtype mapped by `utils.numpy_dtype`.

  `name` is accepted for API compatibility and ignored.
  """
  target_dtype = utils.numpy_dtype(dtype)
  return np.zeros_like(input, dtype=target_dtype)


# --- Begin Public Functions --------------------------------------------------


# Public wrappers: `utils.copy_docstring(<tf symbol name>, <implementation>)`.
# The `name` argument of each lambda is accepted and ignored.
concat = utils.copy_docstring('tf.concat', _concat)


expand_dims = utils.copy_docstring(
    'tf.expand_dims',
    lambda input, axis, name=None: np.expand_dims(input, axis))

fill = utils.copy_docstring(
    'tf.fill',
    lambda dims, value, name=None: np.full(
        dims, ops.convert_to_tensor(value)))

gather = utils.copy_docstring('tf.gather', _gather)
Esempio n. 3
0
    pass

  def watch(self, tensor):  # pylint: disable=unused-argument
    """Accept `tensor` and do nothing with it."""
    del tensor

  def gradient(self, target, sources, output_gradients=None,  # pylint: disable=unused-argument
               unconnected_gradients=None):  # pylint: disable=unused-argument
    """Not supported here.

    Raises:
      NotImplementedError: always.
    """
    del target, sources, output_gradients, unconnected_gradients
    raise NotImplementedError

  def batch_jacobian(self, target, source,  # pylint: disable=unused-argument
                     unconnected_gradients=None,  # pylint: disable=unused-argument
                     parallel_iterations=None, experimental_use_pfor=True):  # pylint: disable=unused-argument
    """Not supported here.

    Raises:
      NotImplementedError: always.
    """
    del target, source, unconnected_gradients
    del parallel_iterations, experimental_use_pfor
    raise NotImplementedError

# `bitcast` converts `input` with `dtype_hint=type`, then calls `.view(type)`.
bitcast = utils.copy_docstring(
    'tf.bitcast',
    lambda input, type, name=None: (  # pylint: disable=g-long-lambda
        convert_to_tensor(input, dtype_hint=type).view(type)))

broadcast_dynamic_shape = utils.copy_docstring(
    'tf.broadcast_dynamic_shape',
    _broadcast_dynamic_shape)

broadcast_static_shape = utils.copy_docstring(
    'tf.broadcast_static_shape',
    _broadcast_static_shape)

broadcast_to = utils.copy_docstring(
    'tf.broadcast_to',
    lambda input, shape, name=None: np.broadcast_to(input, shape))


def _cast(x, dtype):
  x = np.asarray(x)
Esempio n. 4
0
def _assert_rank_at_least(x, rank, message=None, name=None):
    del name
    if len(x.shape) < rank:
        raise ValueError(
            'Expected rank at least {} but got shape {} {}'.format(
                rank, x.shape, message or ''))


def _assert_rank_in(*_, **__):  # pylint: disable=unused-argument
    pass


# --- Begin Public Functions --------------------------------------------------

# `Assert` is a no-op here: the lambda always returns None.
Assert = utils.copy_docstring(  # pylint: disable=invalid-name
    tf.debugging.Assert,
    lambda condition, data, summarize=None, name=None: None)

assert_equal = utils.copy_docstring(
    tf.debugging.assert_equal,
    _assert_equal)

assert_greater = utils.copy_docstring(
    tf.debugging.assert_greater,
    _assert_greater)

assert_less = utils.copy_docstring(
    tf.debugging.assert_less,
    _assert_less)

assert_rank = utils.copy_docstring(
    tf.debugging.assert_rank,
    _assert_rank)

assert_scalar = utils.copy_docstring(
    tf.debugging.assert_scalar,
    _assert_scalar)

assert_greater_equal = utils.copy_docstring(tf.debugging.assert_greater_equal,
Esempio n. 5
0
    else:
        maxval = dtype(1) if maxval is None else maxval
        shape = _bcast_shape(shape, [minval, maxval])
        # We must match ranks, as lax.max refuses to broadcast different-rank args.
        minval = minval + np.zeros([1] * final_rank, dtype=dtype)
        maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)
        return jaxrand.uniform(key=seed,
                               shape=shape,
                               dtype=dtype,
                               minval=minval,
                               maxval=maxval)


# --- Begin Public Functions --------------------------------------------------

# Each sampler picks its `_jax` implementation when JAX_MODE is set.
categorical = utils.copy_docstring(
    tf.random.categorical,
    _categorical_jax if JAX_MODE else _categorical)

gamma = utils.copy_docstring(
    tf.random.gamma,
    _gamma_jax if JAX_MODE else _gamma)

normal = utils.copy_docstring(
    tf.random.normal,
    _normal_jax if JAX_MODE else _normal)

poisson = utils.copy_docstring(
    tf.random.poisson,
    _poisson_jax if JAX_MODE else _poisson)

shuffle = utils.copy_docstring(
    tf.random.shuffle,
    _shuffle_jax if JAX_MODE else _shuffle)

uniform = utils.copy_docstring(
    tf.random.uniform,
    _uniform_jax if JAX_MODE else _uniform)
Esempio n. 6
0

def _rfftn(x, s, axes):
  """N-D real-input FFT of `x` over `axes` via `np.fft.rfftn`.

  The result is cast to `np.result_type(np.complex64, x.dtype)`.
  """
  x = ops.convert_to_tensor(x)
  out_dtype = np.result_type(np.complex64, x.dtype)
  spectrum = np.fft.rfftn(x, s=s, axes=axes)
  return spectrum.astype(out_dtype)


def _irfftn(x, s, axes):
  """Inverse N-D real FFT of `x` over `axes` via `np.fft.irfftn`.

  The result is cast to `np.finfo(x.dtype).dtype`, the real float type that
  matches the precision of `x`.
  """
  x = ops.convert_to_tensor(x)
  out_dtype = np.finfo(x.dtype).dtype
  signal = np.fft.irfftn(x, s=s, axes=axes)
  return signal.astype(out_dtype)


# 1-, 2- and 3-D transforms act on the trailing 1, 2 or 3 axes respectively.
fft = utils.copy_docstring(
    'tf.signal.fft', lambda input, name=None: _fftn(input, axes=[-1]))

fft2d = utils.copy_docstring(
    'tf.signal.fft2d', lambda input, name=None: _fftn(input, axes=[-2, -1]))

fft3d = utils.copy_docstring(
    'tf.signal.fft3d',
    lambda input, name=None: _fftn(input, axes=[-3, -2, -1]))

ifft = utils.copy_docstring(
    'tf.signal.ifft', lambda input, name=None: _ifftn(input, axes=[-1]))

ifft2d = utils.copy_docstring(
Esempio n. 7
0
  x = np.array(x)
  if len(x.shape) != 1:
    raise tf.errors.InvalidArgumentError('unique expects a 1D vector.')
  y, idx = np.unique(x,
                     return_index=True,
                     return_inverse=False,
                     return_counts=False,
                     axis=None)
  idx = idx.astype(utils.numpy_dtype(out_idx))
  return _UniqueOutput(y=y, idx=idx)


# --- Begin Public Functions --------------------------------------------------

argsort = utils.copy_docstring(tf.argsort, _argsort)

sort = utils.copy_docstring(tf.sort, _sort)

tensor_scatter_nd_add = utils.copy_docstring(
    tf.tensor_scatter_nd_add, _tensor_scatter_nd_add)

tensor_scatter_nd_sub = utils.copy_docstring(
    tf.tensor_scatter_nd_sub, _tensor_scatter_nd_sub)

tensor_scatter_nd_update = utils.copy_docstring(
Esempio n. 8
0
  if maximum_iterations is None:
    def override_body_fn(args):
      return body(*args)
    def override_cond_fn(args):
      return cond(*args)
    return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
  else:  # Use else to avoid linter saying these functions are already defined.
    def override_body_fn(args):
      i, args = args
      return i + 1, body(*args)
    def override_cond_fn(args):
      i, args = args
      return cond(*args) & (i < maximum_iterations)
    return lax.while_loop(
        override_cond_fn, override_body_fn, (np.array(0), loop_vars))[1]


# --- Begin Public Functions --------------------------------------------------

cond = utils.copy_docstring(tf.cond, _cond_jax if JAX_MODE else _cond)

no_op = utils.copy_docstring(tf.no_op, _no_op)

# Under JAX_MODE the loop is delegated to `_while_loop_jax`.
while_loop = utils.copy_docstring(
    tf.while_loop, _while_loop_jax if JAX_MODE else _while_loop)
Esempio n. 9
0
def _complex(real, imag, name=None):  # pylint: disable=unused-argument
    """Return `real + 1j * imag` as a complex array.

    Both parts are first cast to their common dtype (computed with a float32
    hint). A float32 common dtype yields complex64; any other yields
    complex128.
    """
    dtype = utils.common_dtype([real, imag], dtype_hint=float32)
    real = np.array(real, dtype=dtype)
    imag = np.array(imag, dtype=dtype)
    complex_dtype = complex64 if as_dtype(dtype) == float32 else complex128
    return real + imag * complex_dtype(1j)


# --- Begin Public Functions --------------------------------------------------

# `as_dtype` maps a dtype-like value (or anything with a `.name`) to the numpy
# scalar type.
as_dtype = utils.copy_docstring(
    'tf.as_dtype',
    lambda type_value: np.dtype(  # pylint: disable=g-long-lambda
        type_value.name if hasattr(type_value, 'name') else type_value).type)


def real_dtype(dtype):
    """Return the dtype of `np.real` applied to an empty array of `dtype`."""
    empty = np.zeros((0,), dtype=as_dtype(dtype))
    return np.real(empty).dtype


bool = np.bool_  # pylint: disable=redefined-builtin

complex = utils.copy_docstring('tf.complex', _complex)  # pylint: disable=redefined-builtin

complex128 = np.complex128

complex64 = np.complex64

double = np.double
Esempio n. 10
0

class _SingleReplicaContext(object):
    """Dummy single-replica context: replica 0 of 1.

    Under JAX_MODE both properties raise NotImplementedError instead.
    """

    @property
    def replica_id_in_sync_group(self):
        """Always 0 outside JAX_MODE."""
        if not JAX_MODE:
            return 0
        raise NotImplementedError

    @property
    def num_replicas_in_sync(self):
        """Always 1 outside JAX_MODE."""
        if not JAX_MODE:
            return 1
        raise NotImplementedError


# --- Begin Public Functions --------------------------------------------------

# Namespace-like objects built as instances of one-field namedtuples:
# `compat.dimension_value` is `dimension_value`, and
# `distribute.get_replica_context()` constructs a `_SingleReplicaContext`.
compat = collections.namedtuple(
    'compat', 'dimension_value')(dimension_value)

distribute = collections.namedtuple('distribute', 'get_replica_context')(
    _SingleReplicaContext)

function = utils.copy_docstring('tf.function', _function)

eye, matmul = linalg.eye, linalg.matmul

# Remove `collections` and `utils` from the module namespace.
del collections, utils
Esempio n. 11
0
    # because logsumexp is often used.
    m = _max_mask_non_finite(input_tensor, axis=axis, keepdims=True)
    y = input_tensor - m
    y = np.exp(y, out=y)
    return m + np.log(np.sum(y, axis=_astuple(axis), keepdims=keepdims))


def _top_k(input, k=1, sorted=True, name=None):  # pylint: disable=unused-argument,redefined-builtin
  raise NotImplementedError


# --- Begin Public Functions --------------------------------------------------


abs = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.math.abs, lambda x, name=None: np.abs(x))

# `accumulate_n` sums the inputs (each converted with np.array) and casts the
# total to `tensor_dtype`; `shape` and `name` are ignored.
accumulate_n = utils.copy_docstring(
    tf.math.accumulate_n,
    lambda inputs, shape=None, tensor_dtype=None, name=None: (  # pylint: disable=g-long-lambda
        sum(map(np.array, inputs)).astype(utils.numpy_dtype(tensor_dtype))))

acos = utils.copy_docstring(tf.math.acos, lambda x, name=None: np.arccos(x))

acosh = utils.copy_docstring(tf.math.acosh, lambda x, name=None: np.arccosh(x))
Esempio n. 12
0
    if JAX_MODE:
        min_z = (minvals - means) / stddevs
        max_z = (maxvals - means) / stddevs

        min_z = _right_expand(min_z, shape)
        max_z = _right_expand(max_z, shape)
        means = _right_expand(means, shape)
        stddevs = _right_expand(stddevs, shape)

        z = random.truncated_normal(seed,
                                    lower=min_z,
                                    upper=max_z,
                                    shape=shape,
                                    dtype=dtype)
        return z * stddevs + means

    raise NotImplementedError


parameterized_truncated_normal = utils.copy_docstring(
    'random_ops.parameterized_truncated_normal', _parameterized_truncated_normal)


def _prevent_gradient(input, message='', name=None):  # pylint: disable=unused-argument,redefined-builtin
    raise NotImplementedError


prevent_gradient = utils.copy_docstring(
    'array_ops.prevent_gradient',
    _prevent_gradient)
Esempio n. 13
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Experimental Numpy backend."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

from tensorflow_probability.python.internal.backend.numpy import _utils as utils


__all__ = [
    'Assert',
    'check_numerics',
]


# No-op: the lambda always returns None.
Assert = utils.copy_docstring(  # pylint: disable=invalid-name
    tf.debugging.Assert,
    lambda condition, data, summarize=None, name=None: None)


# Identity on `x`; all further arguments are ignored.
check_numerics = utils.copy_docstring(
    tf.debugging.check_numerics,
    lambda x, *args, **kwargs: x)
Esempio n. 14
0
  x = convert_to_tensor(input)
  if hasattr(shape, 'as_list'):
    shape = shape.as_list()
  if shape is None or any(s is None for s in shape):
    return x
  return np.reshape(x, shape)


# --- Begin Public Functions --------------------------------------------------


matrix_determinant, matrix_solve = linalg_impl.det, linalg_impl.solve

colocate_with = utils.copy_docstring(tf1.colocate_with, _dummy_scope)

control_flow_v2_enabled = utils.copy_docstring(
    tf1.control_flow_v2_enabled, lambda: True)

get_variable = utils.copy_docstring(tf1.get_variable, _get_variable)

get_variable_scope = utils.copy_docstring(
    tf1.get_variable_scope, lambda: variable_scope(name_or_scope=None))

placeholder_with_default = utils.copy_docstring(
Esempio n. 15
0
    nbatch = int(np.prod(matrix.shape[:-2]))
    flat_mat = matrix.reshape(nbatch, dim, dim)
    flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
    result = np.empty(flat_rhs.shape)
    if np.size(result):
        # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
        for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
            result[i] = scipy_linalg.solve_triangular(
                mat, rh, lower=lower, trans='C' if adjoint else 'N')
    return result.reshape(*rhs.shape)


# --- Begin Public Functions --------------------------------------------------

# `adjoint` is the conjugate transpose of the innermost matrix dimensions.
adjoint = utils.copy_docstring(
    tf.linalg.adjoint,
    lambda matrix, name=None: _matrix_transpose(matrix, conjugate=True))

band_part = utils.copy_docstring(
    tf.linalg.band_part,
    _band_part)

cholesky = utils.copy_docstring(
    tf.linalg.cholesky,
    lambda input, name=None: np.linalg.cholesky(input))

cholesky_solve = utils.copy_docstring(
    tf.linalg.cholesky_solve,
    _cholesky_solve)

det = utils.copy_docstring(
    tf.linalg.det,
    lambda input, name=None: np.linalg.det(input))

diag = utils.copy_docstring(
    tf.linalg.diag,
    _diag)
Esempio n. 16
0
    logits,
    axis=-1,
    name=None):
  """Softmax cross entropy with logits."""
  cost = -np.sum(
      np.where(labels == 0, np.zeros_like(labels),
               labels * (logits - reduce_logsumexp(
                   logits, axis=axis, keepdims=True))),
      axis=axis)
  return cost.astype(logits.dtype)


# --- Begin Public Functions --------------------------------------------------

l2_normalize = utils.copy_docstring('tf.nn.l2_normalize', l2_normalize)


def _moments(x, axes, shift=None, keepdims=False, name=None):  # pylint: disable=unused-argument
  """Return `(mean, variance)` of `x` reduced over `axes`.

  Args:
    x: input values.
    axes: axis or axes to reduce over.
    shift: ignored.
    keepdims: if True, reduced axes are kept (size 1) in both outputs.
    name: ignored.

  Returns:
    Tuple `(mean, variance)`.
  """
  # NOTE: If x.dtype is float16, we may want to compute in float32.
  kept_mean = reduce_mean(x, axis=axes, keepdims=True)
  # The gradient backpropagated to the mean from the variance calculation is
  # zero, so using `stop_gradient(mean)` here is safe and cheaper.
  sq_dev = squared_difference(x, stop_gradient(kept_mean))
  variance = reduce_mean(sq_dev, axis=axes, keepdims=keepdims)
  mean = kept_mean if keepdims else numpy_array.squeeze(kept_mean, axes)
  return (mean, variance)

moments = utils.copy_docstring(
Esempio n. 17
0
  nbatch = int(np.prod(matrix.shape[:-2]))
  flat_mat = matrix.reshape(nbatch, dim, dim)
  flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
  result = np.empty(flat_rhs.shape)
  if np.size(result):
    # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
    for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
      result[i] = scipy_linalg.solve_triangular(mat, rh, lower=lower,
                                                trans='C' if adjoint else 'N')
  return result.reshape(*rhs.shape)


# --- Begin Public Functions --------------------------------------------------

# `adjoint` is the conjugate transpose of the innermost matrix dimensions.
adjoint = utils.copy_docstring(
    'tf.linalg.adjoint',
    lambda matrix, name=None: _matrix_transpose(matrix, conjugate=True))

band_part = utils.copy_docstring('tf.linalg.band_part', _band_part)

cholesky = utils.copy_docstring(
    'tf.linalg.cholesky', lambda input, name=None: np.linalg.cholesky(input))

cholesky_solve = utils.copy_docstring('tf.linalg.cholesky_solve',
                                      _cholesky_solve)

det = utils.copy_docstring(
Esempio n. 18
0
    if x is None and y is None:
        return np.where(condition)
    dtype = utils.common_dtype([x, y])
    x = convert_to_tensor(x, dtype=dtype)
    y = convert_to_tensor(y, dtype=dtype)
    while condition.ndim < max(x.ndim, y.ndim):
        condition = condition[..., np.newaxis]  # add 1 dims until rank matches
    return np.where(condition, x, y)


# --- Begin Public Functions --------------------------------------------------

matrix_determinant, matrix_solve = linalg_impl.det, linalg_impl.solve

colocate_with = utils.copy_docstring(
    'tf1.colocate_with',
    _dummy_scope)

control_flow_v2_enabled = utils.copy_docstring(
    'tf1.control_flow_v2_enabled',
    lambda: True)

enable_control_flow_v2 = utils.copy_docstring(
    'tf1.enable_control_flow_v2',
    lambda: None)

get_variable = utils.copy_docstring(
    'tf1.get_variable',
    _get_variable)

get_variable_scope = utils.copy_docstring(
    'tf1.get_variable_scope',
    lambda: variable_scope(name_or_scope=None))

placeholder_with_default = utils.copy_docstring(
    'tf1.placeholder_with_default',
    _placeholder_with_default)
Esempio n. 19
0
    cost = -np.sum(np.where(
        labels == 0, np.zeros_like(labels),
        labels * (logits - reduce_logsumexp(logits, axis=-1, keepdims=True))),
                   axis=-1)
    cost = np.reshape(cost, labels_shape)
    return cost


def _softmax_cross_entropy_with_logits(labels, logits, name=None):  # pylint: disable=unused-argument
    raise NotImplementedError


# --- Begin Public Functions --------------------------------------------------

l2_normalize = utils.copy_docstring(
    tf.nn.l2_normalize,
    l2_normalize)


def _moments(x, axes, shift=None, keepdims=False, name=None):  # pylint: disable=unused-argument
    """Return `(mean, variance)` of `x` reduced over `axes`.

    Args:
      x: input values.
      axes: axis or axes to reduce over.
      shift: ignored.
      keepdims: if True, reduced axes are kept (size 1) in both outputs.
      name: ignored.

    Returns:
      Tuple `(mean, variance)`.
    """
    # NOTE: If x.dtype is float16, we may want to compute in float32.
    kept_mean = reduce_mean(x, axis=axes, keepdims=True)
    # The gradient backpropagated to the mean from the variance calculation
    # is zero, so using `stop_gradient(mean)` here is safe and cheaper.
    sq_dev = squared_difference(x, stop_gradient(kept_mean))
    variance = reduce_mean(sq_dev, axis=axes, keepdims=keepdims)
    mean = kept_mean if keepdims else numpy_array.squeeze(kept_mean, axes)
    return (mean, variance)

Esempio n. 20
0
def _assert_rank_at_least(x, rank, message=None, name=None):
    del name
    if len(x.shape) < rank:
        raise ValueError(
            'Expected rank at least {} but got shape {} {}'.format(
                rank, x.shape, message or ''))


def _assert_rank_in(*_, **__):  # pylint: disable=unused-argument
    pass


# --- Begin Public Functions --------------------------------------------------

# `Assert` delegates to `assert_equal(True, condition, message=data)`.
Assert = utils.copy_docstring(  # pylint: disable=invalid-name
    'tf.debugging.Assert',
    lambda condition, data, summarize=None, name=None: (  # pylint: disable=g-long-lambda
        assert_equal(True, condition, message=data)))

assert_all_finite = utils.copy_docstring(
    'tf.debugging.assert_all_finite', _assert_all_finite)

assert_equal = utils.copy_docstring(
    'tf.debugging.assert_equal', _assert_equal)

assert_greater = utils.copy_docstring(
    'tf.debugging.assert_greater', _assert_greater)

assert_less = utils.copy_docstring(
    'tf.debugging.assert_less', _assert_less)

assert_rank = utils.copy_docstring(
    'tf.debugging.assert_rank', _assert_rank)

assert_scalar = utils.copy_docstring('tf.debugging.assert_scalar',
Esempio n. 21
0
def _argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
    """Numpy implementation of `tf.argsort`."""
    if direction == 'ASCENDING':
        pass
    elif direction == 'DESCENDING':
        values = np.negative(values)
    else:
        raise ValueError('Unrecognized direction: {}.'.format(direction))
    return np.argsort(values, axis, kind='stable' if stable else 'quicksort')


def _sort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
    """Numpy implementation of `tf.sort`."""
    if direction == 'ASCENDING':
        pass
    elif direction == 'DESCENDING':
        values = np.negative(values)
    else:
        raise ValueError('Unrecognized direction: {}.'.format(direction))
    result = np.sort(values, axis, kind='stable' if stable else 'quicksort')
    if direction == 'DESCENDING':
        return np.negative(result)
    return result


# --- Begin Public Functions --------------------------------------------------

argsort = utils.copy_docstring(
    tf.argsort,
    _argsort)

sort = utils.copy_docstring(
    tf.sort,
    _sort)
Esempio n. 22
0
from tensorflow_probability.python.internal.backend.numpy import _utils as utils

__all__ = [
    'MatrixDiagPartV2',
]

JAX_MODE = False


def _matrix_diag_part_v2(input, k, padding_value, name=None):  # pylint: disable=redefined-builtin,unused-argument
    """Implements tf.raw_ops.MatrixDiagPartV2, for scalar k."""
    if np.array(k).ndim > 0:
        raise NotImplementedError
    shp = np.shape(input)

    if JAX_MODE:
        if len(shp) > 2:
            from jax import vmap  # pylint: disable=g-import-not-at-top
            return vmap(_matrix_diag_part_v2, (0, None, None))(input, k,
                                                               padding_value)
        return np.diag(input, k=k)

    input = np.reshape(input, (-1, shp[-2], shp[-1]))
    output = np.array([np.diag(arr, k=k) for arr in input])
    return output.reshape(*(shp[:-2] + output.shape[1:]))


MatrixDiagPartV2 = utils.copy_docstring(  # pylint: disable=invalid-name
    'tf.raw_ops.MatrixDiagPartV2',
    _matrix_diag_part_v2)
Esempio n. 23
0
      out.append(arg)
    out = nest.map_structure(lambda *x: np.stack(x, axis=0), *out)

  if prepend is not None:
    out = nest.map_structure(
        lambda p, o: np.concatenate([p[np.newaxis], o], axis=0), prepend, out)

  ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
  return nest.map_structure(ordering, out, expand_composites=True)


# --- Begin Public Functions --------------------------------------------------


foldl = utils.copy_docstring('tf.foldl', _foldl_jax if JAX_MODE else _foldl)

map_fn = utils.copy_docstring('tf.map_fn', _map_fn)

vectorized_map = utils.copy_docstring('tf.vectorized_map', _vectorized_map)


def pfor(fn, n):
  if JAX_MODE:
    import jax  # pylint: disable=g-import-not-at-top
    return jax.vmap(fn)(np.arange(n))
Esempio n. 24
0
    # 'dynamic_stitch',
    # 'map_fn',
    # 'scan',
]


def _no_op(_):
  pass


def _while_loop(cond, body, loop_vars,
                shape_invariants=None, parallel_iterations=10,  # pylint: disable=unused-argument
                back_prop=True, swap_memory=False,  # pylint: disable=unused-argument
                maximum_iterations=None, name=None):  # pylint: disable=unused-argument
  i = 0
  while (cond(*loop_vars) and
         (maximum_iterations is None or i < maximum_iterations)):
    loop_vars = body(*loop_vars)
    i += 1
  return loop_vars

# --- Begin Public Functions --------------------------------------------------

no_op = utils.copy_docstring(tf.no_op, _no_op)

while_loop = utils.copy_docstring(tf.while_loop, _while_loop)
      for i, z in enumerate(arg):
        out[i].append(z)

  if prepend is not None:
    out = [pre + list(o) for (pre, o) in zip(prepend, out)]

  ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
  return nest.pack_sequence_as(
      initializer, [ordering(np.array(o)) for o in out])


# --- Begin Public Functions --------------------------------------------------


map_fn = utils.copy_docstring('tf.map_fn', _map_fn)

vectorized_map = utils.copy_docstring('tf.vectorized_map', _vectorized_map)


def pfor(fn, n):
  """Evaluate `fn(i)` for each i in range(n) and stack the results leaf-wise.

  Under JAX_MODE this returns `jax.vmap(fn)(np.arange(n))`. Otherwise `fn` is
  called eagerly `n` times, the leaves of each result are flattened with
  `nest.flatten`, and matching leaves are stacked with `np.array`. The output
  structure comes from the first result, so the eager path needs n >= 1.
  """
  if JAX_MODE:
    import jax  # pylint: disable=g-import-not-at-top
    return jax.vmap(fn)(np.arange(n))
  results = list(map(fn, range(n)))
  leaves_per_call = [nest.flatten(r) for r in results]
  stacked = [np.array(column) for column in zip(*leaves_per_call)]
  return nest.pack_sequence_as(results[0], stacked)
Esempio n. 26
0
                           dtype=dtype)
  else:
    maxval = dtype(1) if maxval is None else maxval
    shape = _bcast_shape(shape, [minval, maxval])
    # We must match ranks, as lax.max refuses to broadcast different-rank args.
    minval = minval + np.zeros([1] * final_rank, dtype=dtype)
    maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)
    return jaxrand.uniform(key=seed, shape=shape, dtype=dtype, minval=minval,
                           maxval=maxval)


# --- Begin Public Functions --------------------------------------------------


stateless_categorical = utils.copy_docstring(
    tf.random.stateless_categorical, _categorical_jax if JAX_MODE else _categorical)

stateless_gamma = utils.copy_docstring(
    tf.random.stateless_gamma, _gamma_jax if JAX_MODE else _gamma)


# TODO(b/147874898): Delete this method.
def gamma(shape, alpha, beta=None, dtype=tf.float32, seed=None, name=None):
  """Handles the difference in shape parameter interpretation."""
  # While we still have usages of tf.random.gamma and tf.random.stateless_gamma,
  # we must handle the different interpretation of the shape argument
  # between the two. `tf.random.gamma` interprets shape as a prefix.
  # `tf.random.stateless_gamma` interprets shape as the full output shape,
  # including as suffix the broadcast of alpha and beta shapes.
Esempio n. 27
0
    nbatch = int(np.prod(matrix.shape[:-2]))
    flat_mat = matrix.reshape(nbatch, dim, dim)
    flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
    result = np.empty(flat_rhs.shape)
    if np.size(result):
        # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
        for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
            result[i] = scipy_linalg.solve_triangular(
                mat, rh, lower=lower, trans='C' if adjoint else 'N')
    return result.reshape(*rhs.shape)


# --- Begin Public Functions --------------------------------------------------

# `adjoint` is the conjugate transpose of the innermost matrix dimensions.
adjoint = utils.copy_docstring(
    'tf.linalg.adjoint',
    lambda matrix, name=None: _matrix_transpose(matrix, conjugate=True))

band_part = utils.copy_docstring(
    'tf.linalg.band_part',
    _band_part)

cholesky = utils.copy_docstring(
    'tf.linalg.cholesky',
    lambda input, name=None: np.linalg.cholesky(input))

cholesky_solve = utils.copy_docstring(
    'tf.linalg.cholesky_solve',
    _cholesky_solve)

det = utils.copy_docstring(
    'tf.linalg.det',
    lambda input, name=None: np.linalg.det(input))

diag = utils.copy_docstring(
    'tf.linalg.diag',
    _diag)
Esempio n. 28
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Numpy bitwise ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow_probability.python.internal.backend.numpy import _utils as utils

__all__ = [
    'bitwise_xor',
    'left_shift',
]

# Element-wise bit operations; the `name` argument is ignored.
bitwise_xor = utils.copy_docstring(
    'tf.bitwise.bitwise_xor',
    lambda x, y, name=None: np.bitwise_xor(x, y))  # pylint: disable=unused-argument

left_shift = utils.copy_docstring(
    'tf.bitwise.left_shift',
    lambda x, y, name=None: np.left_shift(x, y))  # pylint: disable=unused-argument
Esempio n. 29
0
        sources,
        output_gradients=None,  # pylint: disable=unused-argument
        unconnected_gradients=UnconnectedGradients.NONE):  # pylint: disable=unused-argument
        return sources

    def batch_jacobian(
            self,
            target,
            source,  # pylint: disable=unused-argument
            unconnected_gradients=UnconnectedGradients.NONE,  # pylint: disable=unused-argument
            parallel_iterations=None,
            experimental_use_pfor=True):  # pylint: disable=unused-argument
        """Placeholder: return `source` unchanged; other arguments are ignored."""
        del target, unconnected_gradients
        del parallel_iterations, experimental_use_pfor
        return source


# NOTE(review): `broadcast_dynamic_shape` is wired to `_broadcast_static_shape`
# here (another backend version uses `_broadcast_dynamic_shape`); confirm
# whether that is intentional for this file.
broadcast_dynamic_shape = utils.copy_docstring(
    tf.broadcast_dynamic_shape, _broadcast_static_shape)

broadcast_static_shape = utils.copy_docstring(
    tf.broadcast_static_shape, _broadcast_static_shape)

broadcast_to = utils.copy_docstring(
    tf.broadcast_to,
    lambda input, shape, name=None: np.broadcast_to(input, shape))

# `cast` converts with np.array, then `astype(utils.numpy_dtype(dtype))`.
cast = utils.copy_docstring(
    tf.cast,
    lambda x, dtype, name=None: np.array(x).astype(utils.numpy_dtype(dtype)))

clip_by_value = utils.copy_docstring(
    tf.clip_by_value,
    lambda t, clip_value_min, clip_value_max, name=None:  # pylint: disable=g-long-lambda
Esempio n. 30
0
    return (-np.sort(-input, axis=-1)[..., :k], (n - (np.argsort(
        input[..., ::-1], kind='stable', axis=-1)[..., ::-1]))[..., :k])


def _unsorted_segment_sum(data, segment_ids, num_segments, name=None):
    """Sum the entries of `data` that share a segment id (JAX backend only).

    Args:
      data: array whose leading `np.ndim(segment_ids)` dimensions are indexed
        by `segment_ids`.
      segment_ids: integer array of segment indices, matching a prefix of
        `data.shape`.
      num_segments: number of output segments.
      name: ignored.

    Returns:
      Array of shape `(num_segments,) + data.shape[np.ndim(segment_ids):]`
      with the dtype of `data`.

    Raises:
      NotImplementedError: when not running in JAX_MODE.
    """
    del name
    if not JAX_MODE:
        raise NotImplementedError
    data = np.asarray(data)
    # The accumulator must carry data's trailing (non-segment) dimensions and
    # dtype. A bare `np.zeros(num_segments)` is 1-D with the default float
    # dtype, so updates from higher-rank `data` cannot be applied to it and
    # integer data would be summed as float.
    out_shape = (num_segments,) + tuple(data.shape[np.ndim(segment_ids):])
    sums = np.zeros(out_shape, dtype=data.dtype)
    # NOTE(review): out-of-range / negative ids are handled by
    # `jax.ops.index_add` semantics, not dropped as TF does — confirm.
    return jax.ops.index_add(sums, jax.ops.index[segment_ids], data)


# --- Begin Public Functions --------------------------------------------------

abs = utils.copy_docstring(  # pylint: disable=redefined-builtin
    'tf.math.abs', lambda x, name=None: np.abs(x))

# `accumulate_n` sums the inputs (each converted with np.array) and casts the
# total to `tensor_dtype`; `shape` and `name` are ignored.
accumulate_n = utils.copy_docstring(
    'tf.math.accumulate_n',
    lambda inputs, shape=None, tensor_dtype=None, name=None: (  # pylint: disable=g-long-lambda
        sum(map(np.array, inputs)).astype(utils.numpy_dtype(tensor_dtype))))

acos = utils.copy_docstring(
    'tf.math.acos',
    lambda x, name=None: np.arccos(x))

acosh = utils.copy_docstring(
    'tf.math.acosh',
    lambda x, name=None: np.arccosh(x))

add = utils.copy_docstring(
    'tf.math.add',
    lambda x, y, name=None: np.add(x, y))

add_n = utils.copy_docstring(