def _bincount(arr,
              weights=None,
              minlength=None,
              maxlength=None,  # pylint: disable=unused-argument
              dtype=np.int32,
              name=None):  # pylint: disable=unused-argument
  """Counts number of occurrences of each value in `arr`."""
  # TODO(https://github.com/google/jax/issues/5719): Use np.bincount directly?
  if not JAX_MODE:
    return np.bincount(arr, weights, minlength).astype(utils.numpy_dtype(dtype))
  dtype = utils.numpy_dtype(dtype)
  num_buckets = (np.max(arr) + 1) if np.size(arr) else 0
  if minlength is not None and maxlength is not None and minlength == maxlength:
    # In the case where we can use minlength directly, this helps avoid the
    # use of an abstract value, which prevents JAX JIT.
    num_buckets = minlength
  else:
    if minlength is not None:
      num_buckets = np.maximum(num_buckets, minlength)
    if maxlength is not None:
      num_buckets = np.minimum(num_buckets, maxlength)
  one_hots = one_hot(arr, num_buckets)
  # Reduce over every dimension except the last one.
  axes = tuple(range(0, one_hots.ndim - 1))
  if weights is not None:
    return np.sum(one_hots * weights[..., np.newaxis], axis=axes).astype(dtype)
  return np.sum(one_hots, axis=axes).astype(dtype)
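# A minimal sketch (plain NumPy, helper name hypothetical) of the one-hot
# reduction used above: each row of the one-hot matrix marks exactly one
# bucket, so summing over every axis but the last counts occurrences. This is
# what lets the JAX path avoid np.bincount.
import numpy as np

def bincount_via_one_hot(arr, num_buckets):
  one_hots = np.eye(num_buckets)[arr]  # shape: arr.shape + (num_buckets,)
  return one_hots.sum(axis=tuple(range(one_hots.ndim - 1))).astype(np.int32)

example = np.array([0, 1, 1, 3])
assert (bincount_via_one_hot(example, 4) == np.bincount(example, minlength=4)).all()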
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
  """Emulates tf.convert_to_tensor."""
  dtype = utils.numpy_dtype(dtype)
  dtype_hint = utils.numpy_dtype(dtype_hint)
  if is_tensor(value) and not isinstance(value, Variable):
    if dtype is not None:
      # In NumPy mode, we are lenient on the dtype compatibility check because
      # some codepaths rely on flexible conversion from int/float64 to 32.
      if JAX_MODE and value.dtype != dtype:
        raise TypeError(('Tensor conversion requested dtype {} for array with '
                         'dtype {}: {}').format(dtype, value.dtype, value))
      return value.astype(dtype)
    return value

  conversion_func = tensor_conversion_registry.get(type(value),
                                                   _default_convert_to_tensor)
  ret = None
  if dtype is None and dtype_hint is not None:
    try:
      ret = conversion_func(value, dtype=dtype_hint)
    except (TypeError, ValueError):
      pass
  if ret is None:
    ret = conversion_func(value, dtype=dtype)
  return ret
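# A hedged sketch (names hypothetical, using np.asarray in place of the
# registered conversion function) of the two-stage conversion above: try the
# conversion with dtype_hint first, and silently fall back to an
# unconstrained conversion when the hint cannot be honored.
import numpy as np

def convert_with_hint(value, dtype=None, dtype_hint=None):
  if dtype is None and dtype_hint is not None:
    try:
      return np.asarray(value, dtype=dtype_hint)
    except (TypeError, ValueError):
      pass
  return np.asarray(value, dtype=dtype)

print(convert_with_hint([1, 2], dtype_hint=np.float32).dtype)      # float32
print(convert_with_hint(['a', 'b'], dtype_hint=np.float32).dtype)  # falls back: <U1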
def _bincount(arr,
              weights=None,
              minlength=None,
              maxlength=None,  # pylint: disable=unused-argument
              dtype=np.int32,
              name=None):  # pylint: disable=unused-argument
  """Counts number of occurrences of each value in `arr`."""
  if not JAX_MODE:
    return np.bincount(arr, weights, minlength).astype(utils.numpy_dtype(dtype))
  dtype = utils.numpy_dtype(dtype)
  num_buckets = np.max(arr) + 1
  if minlength is not None:
    num_buckets = np.maximum(num_buckets, minlength)
  if maxlength is not None:
    num_buckets = np.minimum(num_buckets, maxlength)
  one_hots = one_hot(arr, num_buckets)
  # Reduce over every dimension except the last one.
  axes = tuple(range(0, one_hots.ndim - 1))
  if weights is not None:
    return np.sum(one_hots * weights[..., np.newaxis], axis=axes).astype(dtype)
  return np.sum(one_hots, axis=axes).astype(dtype)
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
  """Emulates tf.convert_to_tensor."""
  assert not tf.is_tensor(value), value
  if isinstance(value, np.ndarray):
    if dtype is not None:
      dtype = utils.numpy_dtype(dtype)
      # if np.result_type(value, dtype) != dtype:
      #   raise ValueError('Expected dtype {} but got {} with dtype {}.'.format(
      #       dtype, value, value.dtype))
      return value.astype(dtype)
    return value
  if isinstance(value, TensorShape):
    value = [int(d) for d in value.as_list()]
  if dtype is None and dtype_hint is not None:
    dtype_hint = utils.numpy_dtype(dtype_hint)
    value = np.array(value)
    if np.size(value):
      # Match TF behavior, which won't downcast e.g. float to int.
      if np.issubdtype(value.dtype, np.complexfloating):
        if not np.issubdtype(dtype_hint, np.complexfloating):
          return value
      if np.issubdtype(value.dtype, np.floating):
        if not np.issubdtype(dtype_hint, np.floating):
          return value
      if np.issubdtype(value.dtype, np.integer):
        if not np.issubdtype(dtype_hint, np.integer):
          return value
    return value.astype(dtype_hint)
  return np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
  """Emulates tf.convert_to_tensor."""
  assert not tf.is_tensor(value), value
  if is_tensor(value):
    if dtype is not None:
      dtype = utils.numpy_dtype(dtype)
      # if np.result_type(value, dtype) != dtype:
      #   raise ValueError('Expected dtype {} but got {} with dtype {}.'.format(
      #       dtype, value, value.dtype))
      return value.astype(dtype)
    return value
  if isinstance(value, Dimension):
    value = _dimension_value(value)
  elif isinstance(value, TensorShape):
    value = value.as_list()
  # In JAX mode, onp.ndarray/onp.generic are not identified as Tensors.
  # By default, use the dtype of the values passed in.
  elif hasattr(value, 'dtype'):
    if dtype is not None:
      dtype = utils.numpy_dtype(dtype)
      return np.array(value).astype(dtype)
    return np.array(value)
  if dtype is None and dtype_hint is not None:
    dtype_hint = utils.numpy_dtype(dtype_hint)
    value = np.array(value)
    if np.size(value):
      # Match TF behavior, which won't downcast e.g. float to int.
      if np.issubdtype(value.dtype, np.complexfloating):
        if not np.issubdtype(dtype_hint, np.complexfloating):
          return value
      if np.issubdtype(value.dtype, np.floating):
        if not (np.issubdtype(dtype_hint, np.floating) or
                np.issubdtype(dtype_hint, np.complexfloating)):
          return value
      if np.issubdtype(value.dtype, np.integer):
        if not (np.issubdtype(dtype_hint, np.integer) or
                np.issubdtype(dtype_hint, np.floating) or
                np.issubdtype(dtype_hint, np.complexfloating)):
          return value
    return value.astype(dtype_hint)
  np_value = np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))
  if np.issubdtype(np_value.dtype, np.object_):
    raise ValueError('Numpy `object`s cannot be converted to `Tensor`s.')
  # We have no hints. By default JAX (in x64 mode) and NumPy default to
  # {int64,float64}, which does not match TF's default.
  if dtype is None and dtype_hint is None:
    # If the integer doesn't fit in int32, return an int64. This matches TF.
    if isinstance(value, int):
      if value > onp.iinfo(onp.int32).max or value < onp.iinfo(onp.int32).min:
        return np.array(value, dtype=np.int64)
    if np.issubdtype(np_value.dtype, np.floating):
      return np_value.astype(np.float32)
    if np.issubdtype(np_value.dtype, np.integer):
      return np_value.astype(np.int32)
  return np_value
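# A simplified illustration (plain NumPy, helper name hypothetical) of the
# "no downcast" rule above: a dtype_hint is honored only when it does not
# lose the value's kind, so an int array accepts a float hint but a float
# array ignores an int hint. The full version above also widens to complex.
import numpy as np

def honor_hint(value, dtype_hint):
  value = np.array(value)
  if (np.issubdtype(value.dtype, np.floating) and
      not np.issubdtype(dtype_hint, np.floating)):
    return value  # refuse to downcast float to int
  return value.astype(dtype_hint)

print(honor_hint([1, 2], np.float32).dtype)    # float32: int -> float is fine
print(honor_hint([1.5, 2.5], np.int32).dtype)  # float64: int hint ignored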
def _range(start, limit=None, delta=1, dtype=None, name='range'):  # pylint: disable=unused-argument
  dtype = utils.numpy_dtype(dtype or utils.common_dtype([start], np.int32))
  start = ops.convert_to_tensor(start, dtype=dtype)
  limit = None if limit is None else ops.convert_to_tensor(limit, dtype=dtype)
  delta = ops.convert_to_tensor(delta, dtype=dtype)
  return np.arange(start, limit, delta).astype(dtype)
def __init__(self,
             dtype,
             size=None,
             dynamic_size=None,
             clear_after_read=None,
             tensor_array_name=None,
             handle=None,
             flow=None,
             infer_shape=True,
             element_shape=None,
             colocate_with_first_write_call=True,
             data=None,
             name=None):
  self._dtype = utils.numpy_dtype(dtype)
  if data is None:
    if JAX_MODE and size is not None and element_shape is not None:
      data = np.empty((size,) + element_shape, dtype=self._dtype)
    else:
      data = [None] * (0 if size is None else int(size))
  self._data = data
  self._size = size
  self._dynamic_size = dynamic_size
  self._clear_after_read = clear_after_read
  self._tensor_array_name = tensor_array_name
  self._handle = handle
  self._flow = flow
  self._infer_shape = infer_shape
  self._element_shape = element_shape
  self._colocate_with_first_write_call = colocate_with_first_write_call
  self._name = name
def _eye(num_rows, num_columns=None, batch_shape=None,
         dtype=np.float32, name=None):  # pylint: disable=unused-argument
  dt = utils.numpy_dtype(dtype)
  x = np.eye(num_rows, num_columns).astype(dt)
  if batch_shape is not None:
    x = x * np.ones(tuple(batch_shape) + (1, 1)).astype(dt)
  return x
def _one_hot(  # pylint: disable=unused-argument
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None):
  """One hot."""
  if on_value is None:
    on_value = 1
  if off_value is None:
    off_value = 0
  if dtype is None:
    dtype = utils.common_dtype([on_value, off_value], np.float32)
  else:
    dtype = utils.numpy_dtype(dtype)
  indices = np.array(indices)
  depth = np.array(depth)
  pred = abs(np.arange(depth, dtype=indices.dtype) -
             indices[..., np.newaxis]) > 0
  y_out = np.where(pred, np.array(off_value, dtype), np.array(on_value, dtype))
  if axis is not None:
    y_out = np.moveaxis(y_out, -1, axis)
  return y_out
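# A quick check (plain NumPy) of the predicate trick above: comparing
# np.arange(depth) against a trailing index axis is zero exactly at the "on"
# position, so np.where can fill in on/off values without any scatter ops.
import numpy as np

indices = np.array([0, 2])
depth = 3
pred = np.abs(np.arange(depth) - indices[..., np.newaxis]) > 0
print(np.where(pred, 0., 1.))
# [[1. 0. 0.]
#  [0. 0. 1.]]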
def _categorical(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
  rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
  dtype = utils.numpy_dtype(dtype or np.int64)
  if not hasattr(logits, 'shape'):
    logits = np.array(logits, np.float32)
  n = logits.shape[-1]
  return rng.choice(n, p=_softmax(logits), size=num_samples).astype(dtype)
def __array__(self, dtype=None):
  if dtype is not None:
    dtype = utils.numpy_dtype(dtype)
    return self.__wrapped__.__array__(dtype)
  # Passing dtype=None to __array__ has differing behavior in numpy.
  # When an `np.ndarray` has `.__array__(None)` invoked, the array is cast
  # to `float64`. Thus we handle this case separately.
  return self.__wrapped__.__array__()
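# A one-line illustration of the NumPy quirk noted above: an explicit dtype
# of None means "the default dtype" (float64), not "keep the current dtype",
# which is why the no-argument call is dispatched separately. (The exact
# behavior of ndarray.__array__ itself has varied across NumPy versions.)
import numpy as np

print(np.dtype(None))  # float64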
def _bincount(arr,
              weights=None,
              minlength=None,
              maxlength=None,  # pylint: disable=unused-argument
              dtype=tf.int32,
              name=None):  # pylint: disable=unused-argument
  return np.bincount(arr, weights, minlength).astype(utils.numpy_dtype(dtype))
def _histogram_fixed_width(values, value_range, nbins=100, dtype=np.int32,
                           name=None):
  """Numpy implementation of `tf.histogram_fixed_width`."""
  del name
  return np.histogram(values, bins=nbins,
                      range=value_range)[0].astype(utils.numpy_dtype(dtype))
def _categorical_jax(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
  dtype = utils.numpy_dtype(dtype or np.int64)
  if not hasattr(logits, 'shape') or not hasattr(logits, 'dtype'):
    logits = np.array(logits, np.float32)
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  if seed is None:
    raise ValueError('Must provide PRNGKey to sample in JAX.')
  z = jaxrand.gumbel(
      key=seed, shape=logits.shape + (num_samples,), dtype=logits.dtype)
  return np.argmax(np.expand_dims(logits, -1) + z, axis=-2).astype(dtype)
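# A minimal JAX sketch of the Gumbel-max trick used above: adding i.i.d.
# Gumbel noise to the logits and taking an argmax draws exact categorical
# samples without ever normalizing. The sample count here is illustrative.
import jax
import jax.numpy as jnp

logits = jnp.log(jnp.array([0.1, 0.6, 0.3]))
key = jax.random.PRNGKey(0)
z = jax.random.gumbel(key, shape=(10000,) + logits.shape)
samples = jnp.argmax(logits + z, axis=-1)
print(jnp.bincount(samples, length=3) / 10000.)  # roughly [0.1, 0.6, 0.3]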
def _binomial(shape, seed, counts, probs, output_dtype=np.int32, name=None):  # pylint: disable=unused-argument
  rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
  invalid_count = (np.int64(counts) < 0) != (counts < 0)
  if np.any(invalid_count):
    raise ValueError('int64 overflow: {} -> {}'.format(
        counts[np.where(invalid_count)],
        np.int64(counts)[np.where(invalid_count)]))
  probs = np.where(counts > 0, probs, 0)
  samps = rng.binomial(np.int64(counts), np.float64(probs), shape)
  return samps.astype(utils.numpy_dtype(output_dtype))
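# A hedged illustration (plain NumPy; exact overflow behavior when casting is
# platform-dependent) of the guard above: a float too large for int64
# typically wraps to a negative value when cast, so comparing the signs of
# the cast and the original flags exactly the overflowing entries.
import numpy as np

counts = np.array([10., 2.**63])  # 2**63 does not fit in int64
invalid = (counts.astype(np.int64) < 0) != (counts < 0)
print(invalid)  # [False  True] on typical platforms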
def _eye(num_rows, num_columns=None, batch_shape=None,
         dtype=tf.float32, name=None):  # pylint: disable=unused-argument
  dt = utils.numpy_dtype(dtype)
  x = np.eye(num_rows, num_columns).astype(dt)
  if batch_shape is not None:
    x = x * np.ones(np.concatenate([batch_shape, [1, 1]], axis=0)).astype(dt)
  return x
def _histogram_fixed_width_bins(values, value_range, nbins=100, dtype=np.int32,
                                name=None):
  """Numpy implementation of `tf.histogram_fixed_width_bins`."""
  del name
  nbins_float = np.array(nbins, values.dtype)
  scaled_values = truediv(values - value_range[0],
                          value_range[1] - value_range[0])
  indices = floor(nbins_float * scaled_values)
  indices = clip_by_value(indices, 0, nbins_float - 1).astype(
      utils.numpy_dtype(dtype))
  return indices
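# A quick check (plain NumPy) of the fixed-width binning above: values are
# rescaled into [0, 1], multiplied by nbins, floored, and clipped so that the
# right edge of the range lands in the last bin rather than one past it.
import numpy as np

values = np.array([0.0, 0.3, 1.0])
lo, hi, nbins = 0.0, 1.0, 4
idx = np.clip(np.floor(nbins * (values - lo) / (hi - lo)), 0, nbins - 1)
print(idx.astype(np.int32))  # [0 1 3]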
def _unique(x, out_idx=tf.int32, name=None):  # pylint: disable=unused-argument
  """Numpy implementation of `tf.unique`."""
  x = np.array(x)
  if len(x.shape) != 1:
    raise tf.errors.InvalidArgumentError(
        None, None, 'unique expects a 1D vector.')
  y, idx = np.unique(x,
                     return_index=True,
                     return_inverse=False,
                     return_counts=False,
                     axis=None)
  idx = idx.astype(utils.numpy_dtype(out_idx))
  return _UniqueOutput(y=y, idx=idx)
def _confusion_matrix(labels,
                      predictions,
                      num_classes=None,
                      weights=None,
                      dtype=np.int32,
                      name=None):
  """Return confusion matrix between predictions and labels."""
  del name
  if num_classes is None:
    num_classes = np.maximum(np.max(predictions), np.max(labels)) + 1
  cmatrix = np.zeros([num_classes, num_classes],
                     dtype=utils.numpy_dtype(dtype))
  if weights is None:
    weights = 1
  if not JAX_MODE:
    np.add.at(cmatrix, [labels, predictions], weights)
    return cmatrix
  return jax.ops.index_add(cmatrix, [labels, predictions], weights)
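# A minimal NumPy sketch of the scatter-add above: np.add.at accumulates
# repeated (label, prediction) pairs, which plain fancy-index assignment
# would silently collapse into a single write.
import numpy as np

labels = np.array([0, 1, 1])
preds = np.array([0, 1, 0])
cm = np.zeros((2, 2), dtype=np.int32)
np.add.at(cm, (labels, preds), 1)
print(cm)
# [[1 0]
#  [1 1]]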
def _range(start, limit=None, delta=1, dtype=None, name='range'):  # pylint: disable=unused-argument
  """Emulates tf.range."""
  # Emulating dtype inference logic from tf.range.
  dtype = utils.numpy_dtype(dtype)
  start = ops.convert_to_tensor(start, dtype=dtype)
  limit = None if limit is None else ops.convert_to_tensor(limit, dtype=dtype)
  delta = ops.convert_to_tensor(delta, dtype=dtype)
  if dtype is None:
    dtype_hierarchy = [np.int32, np.int64, np.float32, np.float64]
    inferred_dtype = max([arg.dtype for arg in [start, limit, delta]
                          if arg is not None],
                         key=dtype_hierarchy.index)
  else:
    inferred_dtype = dtype
  return np.arange(start, limit, delta).astype(inferred_dtype)
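# A small illustration of the promotion rule above: with no explicit dtype,
# the result dtype is the "largest" argument dtype under the fixed hierarchy
# int32 < int64 < float32 < float64, mirroring tf.range.
import numpy as np

dtype_hierarchy = [np.int32, np.int64, np.float32, np.float64]
arg_dtypes = [np.dtype(np.int32), np.dtype(np.float32)]
print(max(arg_dtypes, key=lambda d: dtype_hierarchy.index(d.type)))  # float32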
def __init__(self,
             initial_value=None,
             trainable=True,
             validate_shape=True,
             caching_device=None,
             name=None,
             variable_def=None,
             dtype=None,
             import_scope=None,
             constraint=None,
             shape=None):
  assert constraint is None
  v = convert_to_tensor(initial_value)
  if dtype is not None:
    v = v.astype(utils.numpy_dtype(dtype))
  super(NumpyVariable, self).__init__(v)
  self.initializer = None
def _range(start, limit=None, delta=1, dtype=None, name='range'):  # pylint: disable=unused-argument
  """Emulates tf.range."""
  # Emulating dtype inference logic from tf.range.
  dtype = utils.numpy_dtype(dtype)
  infer_dtype = lambda t: ops.convert_to_tensor(t, dtype=dtype).dtype
  # We must keep start, limit, and delta as static np.arrays, since they
  # determine the size of the result array, which JAX requires to be static.
  start = onp.array(start, dtype=infer_dtype(start))
  limit = None if limit is None else onp.array(limit, dtype=infer_dtype(limit))
  delta = onp.array(delta, dtype=infer_dtype(delta))
  if dtype is None:
    dtype_hierarchy = [np.int32, np.int64, np.float32, np.float64]
    inferred_dtype = max([arg.dtype for arg in [start, limit, delta]
                          if arg is not None],
                         key=dtype_hierarchy.index)
  else:
    inferred_dtype = dtype
  return np.arange(start, limit, delta).astype(inferred_dtype)
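# A hedged illustration of why the bounds above are kept as host-side
# (original numpy) arrays: under jax.jit, arange must know its output length
# at trace time, so a traced bound fails unless it is marked static.
import jax
import jax.numpy as jnp

try:
  jax.jit(lambda n: jnp.arange(n))(5)
except jax.errors.ConcretizationTypeError:
  print('traced bound: arange cannot infer a static size')

print(jax.jit(jnp.arange, static_argnums=0)(5))  # static bound: [0 1 2 3 4]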
def _lu(input, output_idx_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns Lu(lu, p), as TF does."""
  del name
  if JAX_MODE:
    # JAX uses XLA, which can do a batched factorization.
    lu_out, pivots = scipy_linalg.lu_factor(input)
    from jax import lax_linalg  # pylint: disable=g-import-not-at-top
    return Lu(lu_out,
              lax_linalg.lu_pivots_to_permutation(pivots, lu_out.shape[-1]))
  # Scipy can't batch, so we must do so manually.
  nbatch = int(np.prod(input.shape[:-2]))
  dim = input.shape[-1]
  flat_mat = input.reshape(nbatch, dim, dim)
  flat_lu = np.empty((nbatch, dim, dim), dtype=input.dtype)
  flat_piv = np.empty((nbatch, dim), dtype=utils.numpy_dtype(output_idx_type))
  if np.size(flat_lu):  # Avoid non-empty batches of empty matrices.
    for i, mat in enumerate(flat_mat):
      lu_out, pivots = scipy_linalg.lu_factor(mat)
      flat_lu[i] = lu_out
      flat_piv[i] = _lu_pivot_to_permutation(pivots, flat_lu.shape[-1])
  return Lu(flat_lu.reshape(*input.shape), flat_piv.reshape(*input.shape[:-1]))
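# A minimal sketch (plain NumPy/SciPy) of the flatten-loop-reshape batching
# used on the non-JAX path above, since scipy.linalg.lu_factor accepts only
# one matrix at a time.
import numpy as np
from scipy import linalg as scipy_linalg

batch = np.stack([np.eye(2), 2. * np.eye(2)])  # shape (2, 2, 2)
flat = batch.reshape(-1, 2, 2)
flat_lu = np.stack([scipy_linalg.lu_factor(m)[0] for m in flat])
print(flat_lu.reshape(batch.shape).shape)  # (2, 2, 2): batch shape restored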
def __init__(self,
             dtype,
             size=None,
             dynamic_size=None,
             clear_after_read=None,
             tensor_array_name=None,
             handle=None,
             flow=None,
             infer_shape=True,
             element_shape=None,
             colocate_with_first_write_call=True,
             data=None,
             name=None):
  self._dtype = utils.numpy_dtype(dtype)
  if data is None:
    if JAX_MODE and size is not None and element_shape is not None:
      data = np.empty((size,) + tuple(element_shape), dtype=self._dtype)
    # Can be useful for finding failure cases in JAX TensorArray-using code.
    # elif JAX_MODE:
    #   raise ValueError(
    #       'Missing shape argument: size {} element_shape {}'.format(
    #           size, element_shape))
    else:
      data = [None] * (0 if size is None else int(size))
  self._data = data
  self._size = size
  self._dynamic_size = dynamic_size
  self._clear_after_read = clear_after_read
  self._tensor_array_name = tensor_array_name
  self._handle = handle
  self._flow = flow
  self._infer_shape = infer_shape
  self._element_shape = element_shape
  self._colocate_with_first_write_call = colocate_with_first_write_call
  self._name = name
def __init__(self,
             dtype,
             size=None,
             dynamic_size=None,
             clear_after_read=None,
             tensor_array_name=None,
             handle=None,
             flow=None,
             infer_shape=True,
             element_shape=None,
             colocate_with_first_write_call=True,
             name=None):
  self._data = [None] * (size if size else 0)
  self._dtype = utils.numpy_dtype(dtype)
  self._size = size
  self._dynamic_size = dynamic_size
  self._clear_after_read = clear_after_read
  self._tensor_array_name = tensor_array_name
  self._handle = handle
  self._flow = flow
  self._infer_shape = infer_shape
  self._element_shape = element_shape
  self._colocate_with_first_write_call = colocate_with_first_write_call
  self._name = name
def _constant(value, dtype=None, shape=None, name='Const'):  # pylint: disable=unused-argument
  x = np.array(value,
               dtype=None if dtype is None else utils.numpy_dtype(dtype))
  if shape is None:
    return x
  return np.reshape(x, shape)
  return source


broadcast_dynamic_shape = utils.copy_docstring(tf.broadcast_dynamic_shape,
                                               _broadcast_static_shape)

broadcast_static_shape = utils.copy_docstring(tf.broadcast_static_shape,
                                              _broadcast_static_shape)

broadcast_to = utils.copy_docstring(
    tf.broadcast_to,
    lambda input, shape, name=None: np.broadcast_to(input, shape))

cast = utils.copy_docstring(
    tf.cast,
    lambda x, dtype, name=None: np.array(x).astype(utils.numpy_dtype(dtype)))

clip_by_value = utils.copy_docstring(
    tf.clip_by_value,
    lambda t, clip_value_min, clip_value_max, name=None:  # pylint: disable=g-long-lambda
    np.clip(t, clip_value_min, clip_value_max))

constant = utils.copy_docstring(tf.constant, _constant)

control_dependencies = utils.copy_docstring(tf.control_dependencies,
                                            _control_dependencies)

convert_to_tensor = utils.copy_docstring(tf.convert_to_tensor,
                                         _convert_to_tensor)

custom_gradient = utils.copy_docstring(tf.custom_gradient, lambda f: f)
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin,unused-argument
  return np.zeros_like(input, dtype=utils.numpy_dtype(dtype))
meshgrid = utils.copy_docstring(
    'tf.meshgrid',
    np.meshgrid)

norm = utils.copy_docstring(
    'tf.norm',
    norm)

one_hot = utils.copy_docstring(
    'tf.one_hot',
    _one_hot)

ones = utils.copy_docstring(
    'tf.ones',
    lambda shape, dtype=np.float32, name=None: np.ones(  # pylint: disable=g-long-lambda
        shape, utils.numpy_dtype(dtype)))

ones_like = utils.copy_docstring(
    'tf.ones_like',
    _ones_like)

pad = utils.copy_docstring(
    'tf.pad',
    _pad)

range = utils.copy_docstring(  # pylint: disable=redefined-builtin
    'tf.range',
    _range)

rank = utils.copy_docstring(
    'tf.rank',
broadcast_to = utils.copy_docstring(
    'tf.broadcast_to',
    lambda input, shape, name=None: np.broadcast_to(input, shape))


def _cast(x, dtype):
  x = np.asarray(x)
  if (np.issubdtype(x.dtype, np.complexfloating) and
      not np.issubdtype(dtype, np.complexfloating)):
    x = np.real(x)
  return x.astype(dtype)


cast = utils.copy_docstring(
    'tf.cast',
    lambda x, dtype, name=None: _cast(x, utils.numpy_dtype(dtype)))

clip_by_value = utils.copy_docstring(
    'tf.clip_by_value',
    lambda t, clip_value_min, clip_value_max, name=None:  # pylint: disable=g-long-lambda
    np.clip(t, clip_value_min, clip_value_max))

constant = utils.copy_docstring(
    'tf.constant',
    _constant)

control_dependencies = utils.copy_docstring(
    'tf.control_dependencies',
    _control_dependencies)

convert_to_tensor = utils.copy_docstring(
    'tf.convert_to_tensor',
    _convert_to_tensor)
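# A quick check (plain NumPy) of the complex-handling rule in _cast above:
# taking the real part before a complex-to-float cast avoids NumPy's
# ComplexWarning and matches tf.cast, which discards the imaginary part.
import numpy as np

x = np.array([1 + 2j, 3 - 1j])
print(np.real(x).astype(np.float32))  # [1. 3.]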