def _vlog(*_, **__):  # pylint: disable=unused-argument
  pass


def _warn(*_, **__):  # pylint: disable=unused-argument
  pass


def _warning(*_, **__):  # pylint: disable=unused-argument
  pass


# --- Begin Public Functions --------------------------------------------------

TaskLevelStatusMessage = utils.copy_docstring(  # pylint: disable=invalid-name
    tf1.logging.TaskLevelStatusMessage,
    _TaskLevelStatusMessage)

debug = utils.copy_docstring(
    tf1.logging.debug,
    _debug)

error = utils.copy_docstring(
    tf1.logging.error,
    _error)

fatal = utils.copy_docstring(
    tf1.logging.fatal,
    _fatal)

flush = utils.copy_docstring(
    tf1.logging.flush,
    _flush)
  del name
  value = np.array(value)
  return list(
      np.squeeze(x, axis=axis)
      for x in np.split(
          value, value.shape[axis] if num is None else num, axis))


def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin,unused-argument
  return np.zeros_like(input, dtype=utils.numpy_dtype(dtype))


# --- Begin Public Functions --------------------------------------------------

concat = utils.copy_docstring(
    'tf.concat',
    _concat)

expand_dims = utils.copy_docstring(
    'tf.expand_dims',
    lambda input, axis, name=None: np.expand_dims(input, axis))

fill = utils.copy_docstring(
    'tf.fill',
    lambda dims, value, name=None: np.full(dims, ops.convert_to_tensor(value)))

gather = utils.copy_docstring(
    'tf.gather',
    _gather)
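
# A minimal sketch of the split-and-squeeze pattern used in the unstack
# implementation above (pure numpy; nothing from this module is assumed):
# splitting along `axis` and then squeezing that axis reproduces tf.unstack
# semantics.
import numpy as np

value = np.arange(6).reshape(2, 3)
parts = [np.squeeze(x, axis=0) for x in np.split(value, 2, axis=0)]
# parts == [array([0, 1, 2]), array([3, 4, 5])]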
    pass

  def watch(self, tensor):  # pylint: disable=unused-argument
    pass

  def gradient(self, target, sources, output_gradients=None,  # pylint: disable=unused-argument
               unconnected_gradients=None):  # pylint: disable=unused-argument
    raise NotImplementedError

  def batch_jacobian(self, target, source,  # pylint: disable=unused-argument
                     unconnected_gradients=None,  # pylint: disable=unused-argument
                     parallel_iterations=None,
                     experimental_use_pfor=True):  # pylint: disable=unused-argument
    raise NotImplementedError


bitcast = utils.copy_docstring(
    'tf.bitcast',
    lambda input, type, name=None: convert_to_tensor(  # pylint: disable=g-long-lambda
        input, dtype_hint=type).view(type))

broadcast_dynamic_shape = utils.copy_docstring(
    'tf.broadcast_dynamic_shape',
    _broadcast_dynamic_shape)

broadcast_static_shape = utils.copy_docstring(
    'tf.broadcast_static_shape',
    _broadcast_static_shape)

broadcast_to = utils.copy_docstring(
    'tf.broadcast_to',
    lambda input, shape, name=None: np.broadcast_to(input, shape))


def _cast(x, dtype):
  x = np.asarray(x)
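
# A quick sketch of the byte-reinterpretation `bitcast` relies on (pure
# numpy): `ndarray.view` reinterprets the same bytes under a new dtype.
import numpy as np

bits = np.array(1.0, dtype=np.float32).view(np.int32)
# bits == 1065353216, the IEEE-754 bit pattern of float32 1.0.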
def _assert_rank_at_least(x, rank, message=None, name=None):
  del name
  if len(x.shape) < rank:
    raise ValueError(
        'Expected rank at least {} but got shape {} {}'.format(
            rank, x.shape, message or ''))


def _assert_rank_in(*_, **__):  # pylint: disable=unused-argument
  pass


# --- Begin Public Functions --------------------------------------------------

Assert = utils.copy_docstring(  # pylint: disable=invalid-name
    tf.debugging.Assert,
    lambda condition, data, summarize=None, name=None: None)

assert_equal = utils.copy_docstring(tf.debugging.assert_equal, _assert_equal)

assert_greater = utils.copy_docstring(tf.debugging.assert_greater,
                                      _assert_greater)

assert_less = utils.copy_docstring(tf.debugging.assert_less, _assert_less)

assert_rank = utils.copy_docstring(tf.debugging.assert_rank, _assert_rank)

assert_scalar = utils.copy_docstring(tf.debugging.assert_scalar,
                                     _assert_scalar)

assert_greater_equal = utils.copy_docstring(tf.debugging.assert_greater_equal,
                                            _assert_greater_equal)
  else:
    maxval = dtype(1) if maxval is None else maxval
  shape = _bcast_shape(shape, [minval, maxval])
  # We must match ranks, as lax.max refuses to broadcast different-rank args.
  minval = minval + np.zeros([1] * final_rank, dtype=dtype)
  maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)
  return jaxrand.uniform(key=seed, shape=shape, dtype=dtype,
                         minval=minval, maxval=maxval)


# --- Begin Public Functions --------------------------------------------------

categorical = utils.copy_docstring(
    tf.random.categorical,
    _categorical_jax if JAX_MODE else _categorical)

gamma = utils.copy_docstring(tf.random.gamma,
                             _gamma_jax if JAX_MODE else _gamma)

normal = utils.copy_docstring(tf.random.normal,
                              _normal_jax if JAX_MODE else _normal)

poisson = utils.copy_docstring(tf.random.poisson,
                               _poisson_jax if JAX_MODE else _poisson)

shuffle = utils.copy_docstring(tf.random.shuffle,
                               _shuffle_jax if JAX_MODE else _shuffle)

uniform = utils.copy_docstring(tf.random.uniform,
                               _uniform_jax if JAX_MODE else _uniform)
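
# Sketch of the rank-alignment trick above (pure numpy): adding a zeros array
# of shape [1] * rank lifts a scalar bound to the target rank by
# broadcasting, without changing its value.
import numpy as np

minval = np.float32(0.)
lifted = minval + np.zeros([1, 1], dtype=np.float32)
# lifted.shape == (1, 1); still all zeros, but now rank 2.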
def _rfftn(x, s, axes):
  x = ops.convert_to_tensor(x)
  complex_dtype = np.result_type(np.complex64, x.dtype)
  return np.fft.rfftn(x, s=s, axes=axes).astype(complex_dtype)


def _irfftn(x, s, axes):
  x = ops.convert_to_tensor(x)
  float_dtype = np.finfo(x.dtype).dtype
  return np.fft.irfftn(x, s=s, axes=axes).astype(float_dtype)


fft = utils.copy_docstring(
    'tf.signal.fft',
    lambda input, name=None: _fftn(input, axes=[-1]))

fft2d = utils.copy_docstring(
    'tf.signal.fft2d',
    lambda input, name=None: _fftn(input, axes=[-2, -1]))

fft3d = utils.copy_docstring(
    'tf.signal.fft3d',
    lambda input, name=None: _fftn(input, axes=[-3, -2, -1]))

ifft = utils.copy_docstring(
    'tf.signal.ifft',
    lambda input, name=None: _ifftn(input, axes=[-1]))

ifft2d = utils.copy_docstring(
    'tf.signal.ifft2d',
    lambda input, name=None: _ifftn(input, axes=[-2, -1]))
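
# Quick check of the dtype rules the FFT wrappers above rely on (pure numpy):
# result_type pairs float32 with complex64 and float64 with complex128, and
# np.finfo of a complex dtype reports its real component type.
import numpy as np

assert np.result_type(np.complex64, np.float32) == np.complex64
assert np.result_type(np.complex64, np.float64) == np.complex128
assert np.finfo(np.complex64).dtype == np.float32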
  x = np.array(x)
  if len(x.shape) != 1:
    raise tf.errors.InvalidArgumentError('unique expects a 1D vector.')
  # Note: np.unique returns its results in sorted order, whereas tf.unique
  # preserves the order of first appearance.
  y, idx = np.unique(x,
                     return_index=True,
                     return_inverse=False,
                     return_counts=False,
                     axis=None)
  idx = idx.astype(utils.numpy_dtype(out_idx))
  return _UniqueOutput(y=y, idx=idx)


# --- Begin Public Functions --------------------------------------------------

argsort = utils.copy_docstring(
    tf.argsort,
    _argsort)

sort = utils.copy_docstring(
    tf.sort,
    _sort)

tensor_scatter_nd_add = utils.copy_docstring(
    tf.tensor_scatter_nd_add,
    _tensor_scatter_nd_add)

tensor_scatter_nd_sub = utils.copy_docstring(
    tf.tensor_scatter_nd_sub,
    _tensor_scatter_nd_sub)

tensor_scatter_nd_update = utils.copy_docstring(
    tf.tensor_scatter_nd_update,
    _tensor_scatter_nd_update)
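
# Illustration of the numpy behavior relied on above (pure numpy): with
# return_index=True, np.unique yields sorted unique values plus the index of
# each value's first occurrence in the input.
import numpy as np

x = np.array([2, 2, 1, 3])
y, idx = np.unique(x, return_index=True)
# y == array([1, 2, 3]); idx == array([2, 0, 3])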
  if maximum_iterations is None:
    def override_body_fn(args):
      return body(*args)

    def override_cond_fn(args):
      return cond(*args)

    return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
  else:
    # Use `else` to avoid the linter saying these functions are already
    # defined.
    def override_body_fn(args):
      i, args = args
      return i + 1, body(*args)

    def override_cond_fn(args):
      i, args = args
      return cond(*args) & (i < maximum_iterations)

    return lax.while_loop(
        override_cond_fn, override_body_fn, (np.array(0), loop_vars))[1]


# --- Begin Public Functions --------------------------------------------------

cond = utils.copy_docstring(
    tf.cond,
    _cond_jax if JAX_MODE else _cond)

no_op = utils.copy_docstring(
    tf.no_op,
    _no_op)

while_loop = utils.copy_docstring(
    tf.while_loop,
    _while_loop_jax if JAX_MODE else _while_loop)
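
# A runnable sketch of the counter trick above, assuming JAX is installed:
# augmenting the carried state with an iteration index lets lax.while_loop
# emulate `maximum_iterations`.
from jax import lax
import jax.numpy as jnp

max_iter = 5
_, state = lax.while_loop(
    lambda ia: (ia[1][0] < 100.) & (ia[0] < max_iter),
    lambda ia: (ia[0] + 1, (ia[1][0] + 1.,)),
    (jnp.array(0), (jnp.array(0.),)))
# state == (Array(5.),): the loop stopped after max_iter steps.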
def _complex(real, imag, name=None):  # pylint: disable=unused-argument
  dtype = utils.common_dtype([real, imag], dtype_hint=float32)
  real = np.array(real, dtype=dtype)
  imag = np.array(imag, dtype=dtype)
  if as_dtype(dtype) == float32:
    complex_dtype = complex64
  else:
    complex_dtype = complex128
  return real + imag * complex_dtype(1j)


# --- Begin Public Functions --------------------------------------------------

as_dtype = utils.copy_docstring(
    'tf.as_dtype',
    lambda type_value: np.dtype(  # pylint: disable=g-long-lambda
        type_value.name if hasattr(type_value, 'name') else type_value).type)

real_dtype = lambda dtype: np.real(np.zeros((0,), dtype=as_dtype(dtype))).dtype

bool = np.bool_  # pylint: disable=redefined-builtin

complex = utils.copy_docstring('tf.complex', _complex)  # pylint: disable=redefined-builtin

complex128 = np.complex128

complex64 = np.complex64

double = np.double
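
# A quick check of the promotion rule in `_complex` above (pure numpy):
# float32 components yield complex64; wider floats yield complex128.
import numpy as np

z = (np.array([1., 2.], dtype=np.float32)
     + np.array([3., 4.], dtype=np.float32) * np.complex64(1j))
# z.dtype == complex64; z == array([1.+3.j, 2.+4.j], dtype=complex64)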
class _SingleReplicaContext(object):
  """Dummy replica context for numpy."""

  @property
  def replica_id_in_sync_group(self):
    if JAX_MODE:
      raise NotImplementedError
    return 0

  @property
  def num_replicas_in_sync(self):
    if JAX_MODE:
      raise NotImplementedError
    return 1


# --- Begin Public Functions --------------------------------------------------

compat = collections.namedtuple('compat', 'dimension_value')(dimension_value)

distribute = collections.namedtuple(
    'distribute', 'get_replica_context')(_SingleReplicaContext)

function = utils.copy_docstring('tf.function', _function)

eye = linalg.eye
matmul = linalg.matmul

del collections, utils
  # because logsumexp is often used.
  m = _max_mask_non_finite(input_tensor, axis=axis, keepdims=True)
  y = input_tensor - m
  y = np.exp(y, out=y)
  return m + np.log(np.sum(y, axis=_astuple(axis), keepdims=keepdims))


def _top_k(input, k=1, sorted=True, name=None):  # pylint: disable=unused-argument,redefined-builtin
  raise NotImplementedError


# --- Begin Public Functions --------------------------------------------------

abs = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.math.abs,
    lambda x, name=None: np.abs(x))

accumulate_n = utils.copy_docstring(
    tf.math.accumulate_n,
    lambda inputs, shape=None, tensor_dtype=None, name=None: (  # pylint: disable=g-long-lambda
        sum(map(np.array, inputs)).astype(utils.numpy_dtype(tensor_dtype))))

acos = utils.copy_docstring(
    tf.math.acos,
    lambda x, name=None: np.arccos(x))

acosh = utils.copy_docstring(
    tf.math.acosh,
    lambda x, name=None: np.arccosh(x))
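
# Numerical sketch of the max-subtraction identity used above (pure numpy):
# logsumexp(x) == m + log(sum(exp(x - m))) with m = max(x), which keeps
# exp() from overflowing on large inputs.
import numpy as np

x = np.array([1000., 1001.])
m = np.max(x)
lse = m + np.log(np.sum(np.exp(x - m)))
# lse ~= 1001.3133; the naive np.log(np.sum(np.exp(x))) overflows to inf.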
  if JAX_MODE:
    min_z = (minvals - means) / stddevs
    max_z = (maxvals - means) / stddevs
    min_z = _right_expand(min_z, shape)
    max_z = _right_expand(max_z, shape)
    means = _right_expand(means, shape)
    stddevs = _right_expand(stddevs, shape)
    z = random.truncated_normal(seed, lower=min_z, upper=max_z,
                                shape=shape, dtype=dtype)
    return z * stddevs + means
  raise NotImplementedError


parameterized_truncated_normal = utils.copy_docstring(
    'random_ops.parameterized_truncated_normal',
    _parameterized_truncated_normal)


def _prevent_gradient(input, message='', name=None):  # pylint: disable=unused-argument,redefined-builtin
  raise NotImplementedError


prevent_gradient = utils.copy_docstring('array_ops.prevent_gradient',
                                        _prevent_gradient)
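
# A runnable sketch of the standardization above, using scipy instead of JAX:
# sample a truncated standard normal on the standardized bounds
# (bound - mean) / stddev, then shift and scale back.
from scipy import stats

z = stats.truncnorm.rvs(a=-1., b=1., size=3)  # standardized bounds [-1, 1]
x = z * 2. + 5.                               # mean 5, stddev 2 -> x in [3, 7]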
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Experimental Numpy backend."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import tensorflow as tf

from tensorflow_probability.python.internal.backend.numpy import _utils as utils

__all__ = [
    'Assert',
    'check_numerics',
]


Assert = utils.copy_docstring(  # pylint: disable=invalid-name
    tf.debugging.Assert,
    lambda condition, data, summarize=None, name=None: None)

check_numerics = utils.copy_docstring(
    tf.debugging.check_numerics,
    lambda x, *_, **__: x)
  x = convert_to_tensor(input)
  if hasattr(shape, 'as_list'):
    shape = shape.as_list()
  if shape is None or any(s is None for s in shape):
    return x
  return np.reshape(x, shape)


# --- Begin Public Functions --------------------------------------------------

matrix_determinant = linalg_impl.det
matrix_solve = linalg_impl.solve

colocate_with = utils.copy_docstring(
    tf1.colocate_with,
    _dummy_scope)

control_flow_v2_enabled = utils.copy_docstring(
    tf1.control_flow_v2_enabled,
    lambda: True)

get_variable = utils.copy_docstring(
    tf1.get_variable,
    _get_variable)

get_variable_scope = utils.copy_docstring(
    tf1.get_variable_scope,
    lambda: variable_scope(name_or_scope=None))

placeholder_with_default = utils.copy_docstring(
    tf1.placeholder_with_default,
    _placeholder_with_default)
  nbatch = int(np.prod(matrix.shape[:-2]))
  flat_mat = matrix.reshape(nbatch, dim, dim)
  flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
  result = np.empty(flat_rhs.shape)
  if np.size(result):
    # Skip the LAPACK call on empty operands; otherwise STRTRS raises
    # 'ValueError: On entry to STRTRS parameter number 7 had an illegal
    # value.'
    for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
      result[i] = scipy_linalg.solve_triangular(
          mat, rh, lower=lower, trans='C' if adjoint else 'N')
  return result.reshape(*rhs.shape)


# --- Begin Public Functions --------------------------------------------------

adjoint = utils.copy_docstring(
    tf.linalg.adjoint,
    lambda matrix, name=None: _matrix_transpose(matrix, conjugate=True))

band_part = utils.copy_docstring(tf.linalg.band_part, _band_part)

cholesky = utils.copy_docstring(
    tf.linalg.cholesky,
    lambda input, name=None: np.linalg.cholesky(input))

cholesky_solve = utils.copy_docstring(tf.linalg.cholesky_solve,
                                      _cholesky_solve)

det = utils.copy_docstring(tf.linalg.det,
                           lambda input, name=None: np.linalg.det(input))

diag = utils.copy_docstring(tf.linalg.diag, _diag)
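
# A small check of the batch-flattening pattern above, assuming scipy is
# available: each batch member is solved independently, which matches
# tf.linalg.triangular_solve on the stacked operands.
import numpy as np
from scipy import linalg as scipy_linalg

mats = np.tile(np.tril(np.ones((2, 2))), (3, 1, 1))  # batch of lower-tri mats
rhs = np.ones((3, 2, 1))
out = np.stack([scipy_linalg.solve_triangular(m, r, lower=True)
                for m, r in zip(mats, rhs)])
# out.shape == (3, 2, 1); each solve yields [[1.], [0.]].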
                                           logits,
                                           axis=-1,
                                           name=None):
  """Softmax cross entropy with logits."""
  cost = -np.sum(
      np.where(labels == 0,
               np.zeros_like(labels),
               labels * (logits - reduce_logsumexp(
                   logits, axis=axis, keepdims=True))),
      axis=axis)
  return cost.astype(logits.dtype)


# --- Begin Public Functions --------------------------------------------------

l2_normalize = utils.copy_docstring(
    'tf.nn.l2_normalize',
    l2_normalize)


def _moments(x, axes, shift=None, keepdims=False, name=None):  # pylint: disable=unused-argument
  # NOTE: If x.dtype is float16, we may want to compute in float32.
  mean = reduce_mean(x, axis=axes, keepdims=True)
  # NOTE: The gradient backpropagated to the mean from the variance
  # calculation is zero, so we can safely use `stop_gradient(mean)` for
  # efficiency.
  variance = reduce_mean(squared_difference(x, stop_gradient(mean)),
                         axis=axes, keepdims=keepdims)
  if not keepdims:
    mean = numpy_array.squeeze(mean, axes)
  return (mean, variance)


moments = utils.copy_docstring(
    'tf.nn.moments',
    _moments)
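
# A pure-numpy sketch of the two-pass moments computation above: mean with
# keepdims=True so it broadcasts against x, then mean of squared differences.
import numpy as np

x = np.array([[1., 2.], [3., 4.]])
mean = np.mean(x, axis=0, keepdims=True)         # [[2., 3.]]
variance = np.mean(np.square(x - mean), axis=0)  # [1., 1.]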
  nbatch = int(np.prod(matrix.shape[:-2]))
  flat_mat = matrix.reshape(nbatch, dim, dim)
  flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
  result = np.empty(flat_rhs.shape)
  if np.size(result):
    # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
    for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
      result[i] = scipy_linalg.solve_triangular(
          mat, rh, lower=lower, trans='C' if adjoint else 'N')
  return result.reshape(*rhs.shape)


# --- Begin Public Functions --------------------------------------------------

adjoint = utils.copy_docstring(
    'tf.linalg.adjoint',
    lambda matrix, name=None: _matrix_transpose(matrix, conjugate=True))

band_part = utils.copy_docstring(
    'tf.linalg.band_part',
    _band_part)

cholesky = utils.copy_docstring(
    'tf.linalg.cholesky',
    lambda input, name=None: np.linalg.cholesky(input))

cholesky_solve = utils.copy_docstring(
    'tf.linalg.cholesky_solve',
    _cholesky_solve)

det = utils.copy_docstring(
    'tf.linalg.det',
    lambda input, name=None: np.linalg.det(input))
  if x is None and y is None:
    return np.where(condition)
  dtype = utils.common_dtype([x, y])
  x = convert_to_tensor(x, dtype=dtype)
  y = convert_to_tensor(y, dtype=dtype)
  while condition.ndim < max(x.ndim, y.ndim):
    # Append trailing size-1 dims until the ranks match.
    condition = condition[..., np.newaxis]
  return np.where(condition, x, y)


# --- Begin Public Functions --------------------------------------------------

matrix_determinant = linalg_impl.det
matrix_solve = linalg_impl.solve

colocate_with = utils.copy_docstring('tf1.colocate_with', _dummy_scope)

control_flow_v2_enabled = utils.copy_docstring('tf1.control_flow_v2_enabled',
                                               lambda: True)

enable_control_flow_v2 = utils.copy_docstring('tf1.enable_control_flow_v2',
                                              lambda: None)

get_variable = utils.copy_docstring('tf1.get_variable', _get_variable)

get_variable_scope = utils.copy_docstring(
    'tf1.get_variable_scope',
    lambda: variable_scope(name_or_scope=None))

placeholder_with_default = utils.copy_docstring('tf1.placeholder_with_default',
                                                _placeholder_with_default)
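
# Sketch of the rank-matching step above (pure numpy): trailing size-1 axes
# are appended to `condition` until np.where can broadcast it against x, y.
import numpy as np

condition = np.array([True, False])               # shape (2,)
x, y = np.ones((2, 3)), np.zeros((2, 3))
out = np.where(condition[..., np.newaxis], x, y)  # condition lifted to (2, 1)
# out == [[1., 1., 1.], [0., 0., 0.]]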
  cost = -np.sum(
      np.where(labels == 0,
               np.zeros_like(labels),
               labels * (logits - reduce_logsumexp(
                   logits, axis=-1, keepdims=True))),
      axis=-1)
  cost = np.reshape(cost, labels_shape)
  return cost


def _softmax_cross_entropy_with_logits(labels, logits, name=None):  # pylint: disable=unused-argument
  raise NotImplementedError


# --- Begin Public Functions --------------------------------------------------

l2_normalize = utils.copy_docstring(tf.nn.l2_normalize, l2_normalize)


def _moments(x, axes, shift=None, keepdims=False, name=None):  # pylint: disable=unused-argument
  # NOTE: If x.dtype is float16, we may want to compute in float32.
  mean = reduce_mean(x, axis=axes, keepdims=True)
  # NOTE: The gradient backpropagated to the mean from the variance
  # calculation is zero, so we can safely use `stop_gradient(mean)` for
  # efficiency.
  variance = reduce_mean(squared_difference(x, stop_gradient(mean)),
                         axis=axes, keepdims=keepdims)
  if not keepdims:
    mean = numpy_array.squeeze(mean, axes)
  return (mean, variance)
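
# Numerical sketch of the identity used in the cross-entropy code above
# (pure numpy): cost = -sum_i labels_i * (logits_i - logsumexp(logits)).
import numpy as np

labels = np.array([0., 1.])
logits = np.array([1., 2.])
lse = np.log(np.sum(np.exp(logits)))
cost = -np.sum(np.where(labels == 0, 0., labels * (logits - lse)))
# cost ~= 0.3133, i.e. -log(softmax(logits)[1]).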
def _assert_rank_at_least(x, rank, message=None, name=None):
  del name
  if len(x.shape) < rank:
    raise ValueError(
        'Expected rank at least {} but got shape {} {}'.format(
            rank, x.shape, message or ''))


def _assert_rank_in(*_, **__):  # pylint: disable=unused-argument
  pass


# --- Begin Public Functions --------------------------------------------------

Assert = utils.copy_docstring(  # pylint: disable=invalid-name
    'tf.debugging.Assert',
    lambda condition, data, summarize=None, name=None: assert_equal(  # pylint: disable=g-long-lambda
        True, condition, message=data))

assert_all_finite = utils.copy_docstring('tf.debugging.assert_all_finite',
                                         _assert_all_finite)

assert_equal = utils.copy_docstring('tf.debugging.assert_equal',
                                    _assert_equal)

assert_greater = utils.copy_docstring('tf.debugging.assert_greater',
                                      _assert_greater)

assert_less = utils.copy_docstring('tf.debugging.assert_less', _assert_less)

assert_rank = utils.copy_docstring('tf.debugging.assert_rank', _assert_rank)

assert_scalar = utils.copy_docstring('tf.debugging.assert_scalar',
                                     _assert_scalar)
def _argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
  """Numpy implementation of `tf.argsort`."""
  if direction == 'ASCENDING':
    pass
  elif direction == 'DESCENDING':
    values = np.negative(values)
  else:
    raise ValueError('Unrecognized direction: {}.'.format(direction))
  return np.argsort(values, axis, kind='stable' if stable else 'quicksort')


def _sort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
  """Numpy implementation of `tf.sort`."""
  if direction == 'ASCENDING':
    pass
  elif direction == 'DESCENDING':
    values = np.negative(values)
  else:
    raise ValueError('Unrecognized direction: {}.'.format(direction))
  result = np.sort(values, axis, kind='stable' if stable else 'quicksort')
  if direction == 'DESCENDING':
    return np.negative(result)
  return result


# --- Begin Public Functions --------------------------------------------------

argsort = utils.copy_docstring(tf.argsort, _argsort)

sort = utils.copy_docstring(tf.sort, _sort)
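
# Usage sketch of the negation trick above (pure numpy): sorting negated
# values ascending and negating back gives a descending sort, and keeps
# kind='stable' meaningful for ties.
import numpy as np

values = np.array([3, 1, 2])
descending = -np.sort(-values)                       # array([3, 2, 1])
descending_idx = np.argsort(-values, kind='stable')  # array([0, 2, 1])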
from tensorflow_probability.python.internal.backend.numpy import _utils as utils

__all__ = [
    'MatrixDiagPartV2',
]

JAX_MODE = False


def _matrix_diag_part_v2(input, k, padding_value, name=None):  # pylint: disable=redefined-builtin,unused-argument
  """Implements tf.raw_ops.MatrixDiagPartV2, for scalar `k`."""
  if np.array(k).ndim > 0:
    raise NotImplementedError
  shp = np.shape(input)

  if JAX_MODE:
    if len(shp) > 2:
      from jax import vmap  # pylint: disable=g-import-not-at-top
      return vmap(_matrix_diag_part_v2,
                  (0, None, None))(input, k, padding_value)
    return np.diag(input, k=k)

  input = np.reshape(input, (-1, shp[-2], shp[-1]))
  output = np.array([np.diag(arr, k=k) for arr in input])
  return output.reshape(*(shp[:-2] + output.shape[1:]))


MatrixDiagPartV2 = utils.copy_docstring(  # pylint: disable=invalid-name
    'tf.raw_ops.MatrixDiagPartV2',
    _matrix_diag_part_v2)
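
# Usage sketch of the batched diagonal extraction above (pure numpy):
# collapse the batch dims, take each matrix's diagonal, restore the batch.
import numpy as np

batch = np.arange(8).reshape(2, 2, 2)
diags = np.array([np.diag(m, k=0) for m in batch.reshape(-1, 2, 2)])
# diags.reshape(2, 2) == [[0, 3], [4, 7]]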
    out.append(arg)
  out = nest.map_structure(lambda *x: np.stack(x, axis=0), *out)
  if prepend is not None:
    out = nest.map_structure(
        lambda p, o: np.concatenate([p[np.newaxis], o], axis=0),
        prepend, out)
  ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
  return nest.map_structure(ordering, out, expand_composites=True)


# --- Begin Public Functions --------------------------------------------------

foldl = utils.copy_docstring(
    'tf.foldl',
    _foldl_jax if JAX_MODE else _foldl)

map_fn = utils.copy_docstring(
    'tf.map_fn',
    _map_fn)

vectorized_map = utils.copy_docstring(
    'tf.vectorized_map',
    _vectorized_map)


def pfor(fn, n):
  if JAX_MODE:
    import jax  # pylint: disable=g-import-not-at-top
    return jax.vmap(fn)(np.arange(n))
  outs = [fn(i) for i in range(n)]
  flat_outs = [nest.flatten(o) for o in outs]
  return nest.pack_sequence_as(
      outs[0], [np.array(o) for o in zip(*flat_outs)])
    # 'dynamic_stitch',
    # 'map_fn',
    # 'scan',
]


def _no_op(_):
  pass


def _while_loop(cond, body, loop_vars,
                shape_invariants=None, parallel_iterations=10,  # pylint: disable=unused-argument
                back_prop=True, swap_memory=False,  # pylint: disable=unused-argument
                maximum_iterations=None, name=None):  # pylint: disable=unused-argument
  i = 0
  while (cond(*loop_vars) and
         (maximum_iterations is None or i < maximum_iterations)):
    loop_vars = body(*loop_vars)
    i += 1
  return loop_vars


# --- Begin Public Functions --------------------------------------------------

no_op = utils.copy_docstring(
    tf.no_op,
    _no_op)

while_loop = utils.copy_docstring(
    tf.while_loop,
    _while_loop)
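
# Usage sketch for the eager `_while_loop` above: `cond` and `body` receive
# the unpacked loop variables, and `body` returns the updated tuple.
#
#   result = _while_loop(
#       cond=lambda i, acc: i < 5,
#       body=lambda i, acc: (i + 1, acc + i),
#       loop_vars=(0, 0))
#   # result == (5, 10): accumulates 0 + 1 + 2 + 3 + 4.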
    for i, z in enumerate(arg):
      out[i].append(z)
  if prepend is not None:
    out = [pre + list(o) for (pre, o) in zip(prepend, out)]
  ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
  return nest.pack_sequence_as(
      initializer, [ordering(np.array(o)) for o in out])


# --- Begin Public Functions --------------------------------------------------

map_fn = utils.copy_docstring(
    'tf.map_fn',
    _map_fn)

vectorized_map = utils.copy_docstring(
    'tf.vectorized_map',
    _vectorized_map)


def pfor(fn, n):
  if JAX_MODE:
    import jax  # pylint: disable=g-import-not-at-top
    return jax.vmap(fn)(np.arange(n))
  outs = [fn(i) for i in range(n)]
  flat_outs = [nest.flatten(o) for o in outs]
  return nest.pack_sequence_as(
      outs[0], [np.array(o) for o in zip(*flat_outs)])
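
# Usage sketch for `pfor` above: in numpy mode it evaluates fn at 0..n-1 and
# stacks each leaf of the (possibly nested) result along a new leading axis.
#
#   outs = pfor(lambda i: {'x': np.ones(2) * i}, 3)
#   # outs['x'].shape == (3, 2); outs['x'][2] == [2., 2.]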
        dtype=dtype)
  else:
    maxval = dtype(1) if maxval is None else maxval
  shape = _bcast_shape(shape, [minval, maxval])
  # We must match ranks, as lax.max refuses to broadcast different-rank args.
  minval = minval + np.zeros([1] * final_rank, dtype=dtype)
  maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)
  return jaxrand.uniform(key=seed, shape=shape, dtype=dtype,
                         minval=minval, maxval=maxval)


# --- Begin Public Functions --------------------------------------------------

stateless_categorical = utils.copy_docstring(
    tf.random.stateless_categorical,
    _categorical_jax if JAX_MODE else _categorical)

stateless_gamma = utils.copy_docstring(
    tf.random.stateless_gamma,
    _gamma_jax if JAX_MODE else _gamma)


# TODO(b/147874898): Delete this method.
def gamma(shape, alpha, beta=None, dtype=tf.float32, seed=None, name=None):
  """Handles the difference in shape parameter interpretation."""
  # While we still have usages of both tf.random.gamma and
  # tf.random.stateless_gamma, we must handle their different interpretations
  # of the shape argument: `tf.random.gamma` interprets `shape` as a prefix,
  # while `tf.random.stateless_gamma` interprets `shape` as the full output
  # shape, whose suffix is the broadcast of the alpha and beta shapes.
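
# Illustration of the two shape conventions described above, assuming
# alpha.shape == (3,):
#
#   tf.random.gamma(shape=[2], alpha=alpha)
#   # -> output shape (2, 3): `shape` is a prefix; the broadcast of the
#   # alpha/beta batch shapes is appended.
#   tf.random.stateless_gamma(shape=[2, 3], alpha=alpha, seed=seed)
#   # -> output shape (2, 3): `shape` is already the full output shape.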
  nbatch = int(np.prod(matrix.shape[:-2]))
  flat_mat = matrix.reshape(nbatch, dim, dim)
  flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
  result = np.empty(flat_rhs.shape)
  if np.size(result):
    # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
    for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
      result[i] = scipy_linalg.solve_triangular(
          mat, rh, lower=lower, trans='C' if adjoint else 'N')
  return result.reshape(*rhs.shape)


# --- Begin Public Functions --------------------------------------------------

adjoint = utils.copy_docstring(
    'tf.linalg.adjoint',
    lambda matrix, name=None: _matrix_transpose(matrix, conjugate=True))

band_part = utils.copy_docstring('tf.linalg.band_part', _band_part)

cholesky = utils.copy_docstring(
    'tf.linalg.cholesky',
    lambda input, name=None: np.linalg.cholesky(input))

cholesky_solve = utils.copy_docstring('tf.linalg.cholesky_solve',
                                      _cholesky_solve)

det = utils.copy_docstring('tf.linalg.det',
                           lambda input, name=None: np.linalg.det(input))

diag = utils.copy_docstring('tf.linalg.diag', _diag)
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Numpy bitwise ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow_probability.python.internal.backend.numpy import _utils as utils

__all__ = [
    'bitwise_xor',
    'left_shift',
]

bitwise_xor = utils.copy_docstring(
    'tf.bitwise.bitwise_xor',
    lambda x, y, name=None: np.bitwise_xor(x, y))  # pylint: disable=unused-argument

left_shift = utils.copy_docstring(
    'tf.bitwise.left_shift',
    lambda x, y, name=None: np.left_shift(x, y))  # pylint: disable=unused-argument
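
# Usage sketch for the wrappers above (pure numpy): both ops broadcast
# elementwise over integer inputs.
assert np.bitwise_xor(5, 3) == 6  # 0b101 ^ 0b011 == 0b110
assert np.left_shift(1, 3) == 8   # 1 << 3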
               sources, output_gradients=None,  # pylint: disable=unused-argument
               unconnected_gradients=UnconnectedGradients.NONE):  # pylint: disable=unused-argument
    return sources

  def batch_jacobian(self, target, source,  # pylint: disable=unused-argument
                     unconnected_gradients=UnconnectedGradients.NONE,  # pylint: disable=unused-argument
                     parallel_iterations=None,
                     experimental_use_pfor=True):  # pylint: disable=unused-argument
    return source


# Note: in the numpy backend all shapes are static, so the dynamic-shape
# entry point reuses the static implementation.
broadcast_dynamic_shape = utils.copy_docstring(tf.broadcast_dynamic_shape,
                                               _broadcast_static_shape)

broadcast_static_shape = utils.copy_docstring(tf.broadcast_static_shape,
                                              _broadcast_static_shape)

broadcast_to = utils.copy_docstring(
    tf.broadcast_to,
    lambda input, shape, name=None: np.broadcast_to(input, shape))

cast = utils.copy_docstring(
    tf.cast,
    lambda x, dtype, name=None: np.array(x).astype(utils.numpy_dtype(dtype)))

clip_by_value = utils.copy_docstring(
    tf.clip_by_value,
    lambda t, clip_value_min, clip_value_max, name=None:  # pylint: disable=g-long-lambda
    np.clip(t, clip_value_min, clip_value_max))
  return (-np.sort(-input, axis=-1)[..., :k],
          (n - (np.argsort(input[..., ::-1],
                           kind='stable', axis=-1)[..., ::-1]))[..., :k])


def _unsorted_segment_sum(data, segment_ids, num_segments, name=None):
  del name
  if not JAX_MODE:
    raise NotImplementedError
  sums = np.zeros(num_segments)
  # Note: `jax.ops.index_add` was removed in newer JAX releases; the
  # equivalent today is `sums.at[segment_ids].add(data)`.
  return jax.ops.index_add(sums, jax.ops.index[segment_ids], data)


# --- Begin Public Functions --------------------------------------------------

abs = utils.copy_docstring(  # pylint: disable=redefined-builtin
    'tf.math.abs',
    lambda x, name=None: np.abs(x))

accumulate_n = utils.copy_docstring(
    'tf.math.accumulate_n',
    lambda inputs, shape=None, tensor_dtype=None, name=None: (  # pylint: disable=g-long-lambda
        sum(map(np.array, inputs)).astype(utils.numpy_dtype(tensor_dtype))))

acos = utils.copy_docstring('tf.math.acos', lambda x, name=None: np.arccos(x))

acosh = utils.copy_docstring('tf.math.acosh',
                             lambda x, name=None: np.arccosh(x))

add = utils.copy_docstring('tf.math.add', lambda x, y, name=None: np.add(x, y))

add_n = utils.copy_docstring(
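
# Runnable sketch of the same segment sum with the current JAX API, assuming
# JAX is installed (jax.ops.index_add / jax.ops.index were removed upstream;
# scatter-adds with duplicate indices accumulate):
import jax.numpy as jnp

data = jnp.array([1., 2., 3.])
segment_ids = jnp.array([0, 0, 1])
sums = jnp.zeros(2).at[segment_ids].add(data)
# sums == Array([3., 3.])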