def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
  """Emulates tf.convert_to_tensor.

  Conversion rules, applied in order:
    * `value` must not already be a TF tensor. This is enforced with an
      `assert`, so the check is skipped under `python -O`.
    * An `np.ndarray` is returned as-is when `dtype` is None. Otherwise it is
      cast with `astype(dtype)`. The dtype-compatibility check is commented
      out, so lossy casts are not rejected.
    * A `TensorShape` is first turned into a list of Python ints
      (presumably requires a fully defined shape, since `int(None)` fails).
    * If only `dtype_hint` is given, the value goes through `np.array`.
      A non-empty value keeps its NumPy-inferred dtype when the hint would
      change its kind:
        - complex -> non-complex
        - floating -> non-floating
        - integer -> non-integer
      Otherwise it is cast to the hint.
    * Otherwise the value is converted with `np.array` using `dtype` (or the
      hint). When both are None, NumPy infers the dtype.

  Args:
    value: Array-like, `np.ndarray`, or `TensorShape` to convert.
    dtype: Optional required dtype (TF or NumPy dtype).
    dtype_hint: Optional preferred dtype, used only when `dtype` is None.
    name: Ignored.

  Returns:
    A NumPy array (or `value` itself for an ndarray with `dtype=None`).
  """
  assert not tf.is_tensor(value), value
  if isinstance(value, np.ndarray):
    if dtype is not None:
      dtype = utils.numpy_dtype(dtype)
      # if np.result_type(value, dtype) != dtype:
      #   raise ValueError('Expected dtype {} but got {} with dtype {}.'.format(
      #       dtype, value, value.dtype))
      return value.astype(dtype)
    return value
  if isinstance(value, TensorShape):
    value = [int(d) for d in value.as_list()]
  if dtype is None and dtype_hint is not None:
    dtype_hint = utils.numpy_dtype(dtype_hint)
    value = np.array(value)
    if np.size(value):
      # Match TF behavior, which won't downcast e.g. float to int.
      # NOTE(review): this also keeps integer values as integers under a
      # floating hint, and bool values (no matching branch) are cast to the
      # hint. Verify both cases against TF's dtype_hint semantics.
      if np.issubdtype(value.dtype, np.complexfloating):
        if not np.issubdtype(dtype_hint, np.complexfloating):
          return value
      if np.issubdtype(value.dtype, np.floating):
        if not np.issubdtype(dtype_hint, np.floating):
          return value
      if np.issubdtype(value.dtype, np.integer):
        if not np.issubdtype(dtype_hint, np.integer):
          return value
    # Empty values, and values whose kind matches the hint, take the hint.
    return value.astype(dtype_hint)
  return np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))
def _bincount(
    arr, weights=None, minlength=None, maxlength=None,
    dtype=tf.int32, name=None):  # pylint: disable=unused-argument
  """Emulates `tf.math.bincount` over the flattened values of `arr`.

  Args:
    arr: Array-like of non-negative integers. It is flattened first.
    weights: Optional array-like with the same number of elements as `arr`.
      Each bin is incremented by the matching weight instead of by 1.
    minlength: Optional minimum output length; the output is zero-padded.
    maxlength: Optional maximum output length. Values `>= maxlength` are
      skipped, so the output has at most `maxlength` bins.
    dtype: Dtype of the counts when `weights` is None.
    name: Ignored.

  Returns:
    A 1-D array of bin totals. It has the dtype of `weights` when weights are
    given (as in TF), otherwise `dtype`.
  """
  arr = np.asarray(arr).reshape(-1)
  if weights is None:
    out_dtype = utils.numpy_dtype(dtype)
  else:
    weights = np.asarray(weights).reshape(-1)
    # TF: `dtype` only applies when there are no weights. Weighted bins keep
    # the weights' dtype, so float weights are not truncated to integers.
    out_dtype = weights.dtype
  if maxlength is not None:
    # Drop out-of-range values before counting so they cannot grow the output.
    keep = arr < maxlength
    arr = arr[keep]
    if weights is not None:
      weights = weights[keep]
  # NumPy wants an int `minlength` (`None` is deprecated); 0 means no padding.
  counts = np.bincount(arr, weights, 0 if minlength is None else minlength)
  if maxlength is not None:
    # `minlength` may have padded past `maxlength`; `maxlength` wins, as in TF.
    counts = counts[:maxlength]
  return counts.astype(out_dtype)
def _eye(num_rows, num_columns=None, batch_shape=None, dtype=tf.float32,
         name=None):  # pylint: disable=unused-argument
  """Emulates `tf.eye`, optionally replicated over leading batch dimensions.

  Args:
    num_rows: Number of rows in each matrix.
    num_columns: Number of columns; defaults to `num_rows`.
    batch_shape: Optional iterable of leading batch dimensions.
    dtype: Element dtype of the result.
    name: Ignored.

  Returns:
    An array with ones on the main diagonal. Its shape is
    `tuple(batch_shape) + (num_rows, num_columns)`, or just the matrix shape
    when `batch_shape` is None.
  """
  np_dtype = utils.numpy_dtype(dtype)
  identity = np.eye(num_rows, num_columns).astype(np_dtype)
  if batch_shape is None:
    return identity
  # Multiplying by a ones array shaped batch_shape + (1, 1) broadcasts a copy
  # of the matrix into every batch position.
  tiling = np.ones(tuple(batch_shape) + (1, 1)).astype(np_dtype)
  return identity * tiling
def _categorical(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
  """Emulates `tf.random.categorical` using NumPy's random number generator.

  Args:
    logits: 2-D array-like indexed `[batch, class]`. Each row (a slice along
      axis 1) is sampled independently. Inputs without a `shape` attribute
      are converted to a float32 array.
    num_samples: Number of draws per row.
    dtype: Dtype of the returned indices; defaults to int64.
    seed: Optional integer seed, of which only the low 32 bits are used.
      When None, NumPy's global RNG (`np.random`) is used.
    name: Ignored.

  Returns:
    Array of class indices with shape `[batch, num_samples]` and dtype
    `dtype`.
  """
  rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
  dtype = utils.numpy_dtype(np.int64 if dtype is None else dtype)
  if not hasattr(logits, 'shape'):
    logits = np.array(logits, np.float32)
  # Assumes `_softmax` normalizes over the last (class) axis. TODO confirm.
  probs = _softmax(logits)
  n = logits.shape[-1]
  samples = np.apply_along_axis(
      lambda p: rng.choice(n, p=p, size=num_samples), 1, probs)
  # `rng.choice` returns the platform default int; honor the requested dtype.
  return samples.astype(dtype)
def _categorical_jax(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
  """Emulates `tf.random.categorical` in JAX via the Gumbel-max trick.

  Args:
    logits: Array-like of unnormalized log-probabilities whose last axis
      indexes classes. Inputs lacking a `shape` or `dtype` attribute are
      converted to a float32 array.
    num_samples: Number of draws per distribution.
    dtype: Dtype of the returned indices; defaults to int64.
    seed: JAX PRNG key. Required.
    name: Ignored.

  Returns:
    Class indices of shape `logits.shape[:-1] + (num_samples,)`.

  Raises:
    ValueError: If `seed` is None.
  """
  out_dtype = utils.numpy_dtype(dtype or np.int64)
  array_like = hasattr(logits, 'shape') and hasattr(logits, 'dtype')
  if not array_like:
    logits = np.array(logits, np.float32)
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  if seed is None:
    raise ValueError('Must provide PRNGKey to sample in JAX.')
  # Independent Gumbel noise for every (class, sample) pair. The argmax over
  # the class axis of logits + noise is a categorical draw.
  noise = jaxrand.gumbel(
      key=seed, shape=logits.shape + (num_samples, ), dtype=logits.dtype)
  perturbed_logits = np.expand_dims(logits, -1) + noise
  winners = np.argmax(perturbed_logits, axis=-2)
  return winners.astype(out_dtype)
def _lu(input, output_idx_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns Lu(lu, p), as TF does.

  Emulates `tf.linalg.lu` for a possibly batched stack of square matrices.

  Args:
    input: Array whose last two dimensions form square matrices. Any leading
      dimensions are batch dimensions.
    output_idx_type: Dtype of the returned permutation indices, honored in
      both the JAX and the NumPy paths.
    name: Ignored.

  Returns:
    `Lu(lu, p)`: the packed LU factors, shaped like `input`, and the row
    permutation, with shape `input.shape[:-1]` and dtype `output_idx_type`.
  """
  del name
  idx_dtype = utils.numpy_dtype(output_idx_type)
  if JAX_MODE:
    # JAX's lu_factor goes through XLA, which can factor the whole batch at once.
    lu_out, pivots = scipy_linalg.lu_factor(input)
    try:
      from jax import lax_linalg  # pylint: disable=g-import-not-at-top
    except ImportError:
      # Newer JAX releases expose this module as `jax.lax.linalg`.
      from jax.lax import linalg as lax_linalg  # pylint: disable=g-import-not-at-top
    perm = lax_linalg.lu_pivots_to_permutation(pivots, lu_out.shape[-1])
    # Match the NumPy path below, which builds `p` with `output_idx_type`.
    return Lu(lu_out, perm.astype(idx_dtype))
  # Scipy can't batch, so we must do so manually.
  nbatch = int(np.prod(input.shape[:-2]))
  dim = input.shape[-1]
  flat_mat = input.reshape(nbatch, dim, dim)
  flat_lu = np.empty((nbatch, dim, dim), dtype=input.dtype)
  flat_piv = np.empty((nbatch, dim), dtype=idx_dtype)
  if np.size(flat_lu):  # Avoid non-empty batches of empty matrices.
    for i, mat in enumerate(flat_mat):
      lu_out, pivots = scipy_linalg.lu_factor(mat)
      flat_lu[i] = lu_out
      flat_piv[i] = _lu_pivot_to_permutation(pivots, flat_lu.shape[-1])
  return Lu(flat_lu.reshape(*input.shape), flat_piv.reshape(*input.shape[:-1]))
def _constant(value, dtype=None, shape=None, name='Const'):  # pylint: disable=unused-argument
  """Emulates `tf.constant`.

  Args:
    value: Array-like initial value.
    dtype: Optional dtype (TF or NumPy). When None, NumPy infers it.
    shape: Optional result shape. A `value` with exactly one element is
      expanded to fill `shape`, as in TF. Any other value is reshaped, so its
      element count must match.
    name: Ignored.

  Returns:
    A NumPy array holding `value`.
  """
  x = np.array(value, dtype=None if dtype is None else utils.numpy_dtype(dtype))
  if shape is None:
    return x
  if np.size(x) == 1:
    # TF fills `shape` with a single value (e.g. constant(0, shape=(2, 3))).
    # A plain reshape would fail whenever `shape` has more than one element.
    return np.full(shape, np.reshape(x, ()), dtype=x.dtype)
  return np.reshape(x, shape)
_broadcast_static_shape) broadcast_static_shape = utils.copy_docstring(tf.broadcast_static_shape, _broadcast_static_shape) broadcast_static_shape_as_tensorshape = utils.copy_docstring( tf.broadcast_static_shape, functools.partial(_broadcast_static_shape, as_tensorshape=True)) broadcast_to = utils.copy_docstring( tf.broadcast_to, lambda input, shape, name=None: onp.broadcast_to(input, shape)) cast = utils.copy_docstring( tf.cast, lambda x, dtype, name=None: np.array(x).astype(utils.numpy_dtype(dtype))) clip_by_value = utils.copy_docstring( tf.clip_by_value, lambda t, clip_value_min, clip_value_max, name=None: # pylint: disable=g-long-lambda np.clip(t, clip_value_min, clip_value_max)) constant = utils.copy_docstring(tf.constant, _constant) control_dependencies = utils.copy_docstring(tf.control_dependencies, _control_dependencies) convert_to_tensor = utils.copy_docstring(tf.convert_to_tensor, _convert_to_tensor) custom_gradient = utils.copy_docstring(tf.custom_gradient, lambda f: f)
def _top_k(input, k=1, sorted=True, name=None): # pylint: disable=unused-argument,redefined-builtin raise NotImplementedError # --- Begin Public Functions -------------------------------------------------- abs = utils.copy_docstring( # pylint: disable=redefined-builtin tf.math.abs, lambda x, name=None: np.abs(x)) accumulate_n = utils.copy_docstring( tf.math.accumulate_n, lambda inputs, shape=None, tensor_dtype=None, name=None: ( # pylint: disable=g-long-lambda sum(map(np.array, inputs)).astype(utils.numpy_dtype(tensor_dtype)))) acos = utils.copy_docstring(tf.math.acos, lambda x, name=None: np.arccos(x)) acosh = utils.copy_docstring(tf.math.acosh, lambda x, name=None: np.arccosh(x)) add = utils.copy_docstring(tf.math.add, lambda x, y, name=None: np.add(x, y)) add_n = utils.copy_docstring( tf.math.add_n, lambda inputs, name=None: sum(map(np.array, inputs))) angle = utils.copy_docstring(tf.math.angle, lambda input, name=None: np.angle(input)) argmax = utils.copy_docstring( tf.math.argmax,
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Emulates `tf.zeros_like`: zeros shaped like `input`.

  Args:
    input: Array-like whose shape (and, by default, dtype) is used.
    dtype: Optional result dtype. Defaults to the dtype of `input`, or to the
      dtype `np.array` infers when `input` has no `dtype` attribute.
    name: Forwarded to `tf.zeros` on the non-NumPy path.

  Returns:
    An array of zeros with `input`'s shape.
  """
  s = _shape(input)
  if dtype is None:
    # Default to the input's dtype, not `s.dtype`: `s` holds dimension sizes,
    # so its dtype says nothing about the values being mirrored.
    dtype = (input.dtype if hasattr(input, 'dtype')
             else np.array(input).dtype)
  if isinstance(s, (np.ndarray, onp.generic)):
    return np.zeros(s, utils.numpy_dtype(dtype))
  return tf.zeros(s, dtype, name)
def _size(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin, unused-argument
  """Emulates `tf.size`: the total number of elements in `input`.

  The count is the product of the input's shape, which is 1 for a scalar
  because the product of an empty shape is 1. It is returned with dtype
  `out_type`.
  """
  dims = np.array(input).shape
  element_count = np.prod(dims)
  return element_count.astype(utils.numpy_dtype(out_type))
# NOTE(review): samples are cast to the dtype of `start`, so an integer
# `start` truncates them. Presumably callers always pass float endpoints;
# confirm.
linspace = utils.copy_docstring(
    tf.linspace,
    lambda start, stop, num, name=None: (  # pylint: disable=g-long-lambda
        np.linspace(start, stop, num).astype(np.array(start).dtype)))
# NumPy's meshgrid serves directly as the implementation.
meshgrid = utils.copy_docstring(tf.meshgrid, np.meshgrid)
# The right-hand `norm` refers to a module-level `norm` bound before this
# line (defined outside this excerpt). This rebinds the name to the wrapper.
norm = utils.copy_docstring(tf.norm, norm)
one_hot = utils.copy_docstring(tf.one_hot, _one_hot)
# Default dtype is float32, mirroring TF.
ones = utils.copy_docstring(
    tf.ones,
    lambda shape, dtype=tf.float32, name=None: np.ones(  # pylint: disable=g-long-lambda
        shape, utils.numpy_dtype(dtype)))
ones_like = utils.copy_docstring(tf.ones_like, _ones_like)
pad = utils.copy_docstring(tf.pad, _pad)
# With `limit=None`, np.arange treats `start` as the stop value, matching
# tf.range(limit). With `dtype=None`, np.arange presumably infers the dtype
# from its arguments, assuming `utils.numpy_dtype(None)` is None.
# NOTE(review): that may differ from TF's inference for Python floats.
range = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.range,
    lambda start, limit=None, delta=1, dtype=None, name='range': (  # pylint: disable=g-long-lambda
        np.arange(start, limit, delta, utils.numpy_dtype(dtype))))
# Returns `ndim` of the converted input as a Python int.
rank = utils.copy_docstring(tf.rank, lambda input, name=None: np.array(input).ndim)  # pylint: disable=redefined-builtin,g-long-lambda
reshape = utils.copy_docstring(
    tf.reshape,
    lambda tensor, shape, name=None: np.reshape(tensor, shape))