Example #1
def _bincount(
        arr,
        weights=None,
        minlength=None,
        maxlength=None,  # pylint: disable=unused-argument
        dtype=np.int32,
        name=None):  # pylint: disable=unused-argument
    """Counts number of occurences of each value in `arr`."""
    # TODO(https://github.com/google/jax/issues/5719): Use np.bincount directly?
    if not JAX_MODE:
        return np.bincount(arr, weights,
                           minlength).astype(utils.numpy_dtype(dtype))

    dtype = utils.numpy_dtype(dtype)
    num_buckets = (np.max(arr) + 1) if np.size(arr) else 0
    if minlength is not None and maxlength is not None and minlength == maxlength:
        # In the case where we can use minlength directly, this helps avoid the
        # use of an abstract value, which prevents JAX JIT.
        num_buckets = minlength
    else:
        if minlength is not None:
            num_buckets = np.maximum(num_buckets, minlength)
        if maxlength is not None:
            num_buckets = np.minimum(num_buckets, maxlength)
    one_hots = one_hot(arr, num_buckets)
    # Reduce over every dimension except the last one.
    axes = tuple(range(0, one_hots.ndim - 1))
    if weights is not None:
        return np.sum(one_hots * weights[..., np.newaxis],
                      axis=axes).astype(dtype)
    return np.sum(one_hots, axis=axes).astype(dtype)
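
The one-hot reduction exists because `np.bincount` produces a data-dependent output length, which JAX's JIT cannot trace; a fixed `num_buckets` keeps the shape static. A minimal self-contained sketch of the same trick in plain NumPy (the helper name is illustrative, not from the source):

import numpy as np

def one_hot_bincount(arr, num_buckets):
    # One-hot encode, then sum over every axis except the bucket axis;
    # the output width is fixed by num_buckets, not by the data.
    one_hots = np.eye(num_buckets, dtype=np.int32)[arr]
    axes = tuple(range(one_hots.ndim - 1))
    return one_hots.sum(axis=axes)

print(one_hot_bincount(np.array([0, 1, 1, 3]), 4))  # [1 2 0 1]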
Example #2
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
  """Emulates tf.convert_to_tensor."""
  dtype = utils.numpy_dtype(dtype)
  dtype_hint = utils.numpy_dtype(dtype_hint)
  if is_tensor(value) and not isinstance(value, Variable):
    if dtype is not None:
      # In NumPy mode, we are lenient on the dtype compatibility check because
      # some codepaths rely on flexible conversion from int64/float64 to their
      # 32-bit counterparts.
      if JAX_MODE and value.dtype != dtype:
        raise TypeError(('Tensor conversion requested dtype {} for array with '
                         'dtype {}: {}').format(dtype, value.dtype, value))
      return value.astype(dtype)
    return value

  conversion_func = tensor_conversion_registry.get(type(value),
                                                   _default_convert_to_tensor)
  ret = None
  if dtype is None and dtype_hint is not None:
    try:
      ret = conversion_func(value, dtype=dtype_hint)
    except (TypeError, ValueError):
      pass

  if ret is None:
    ret = conversion_func(value, dtype=dtype)
  return ret
Example #3
def _bincount(
        arr,
        weights=None,
        minlength=None,
        maxlength=None,  # pylint: disable=unused-argument
        dtype=np.int32,
        name=None):  # pylint: disable=unused-argument
    """Counts number of occurences of each value in `arr`."""
    if not JAX_MODE:
        return np.bincount(arr, weights,
                           minlength).astype(utils.numpy_dtype(dtype))

    dtype = utils.numpy_dtype(dtype)
    num_buckets = np.max(arr) + 1
    if minlength is not None:
        num_buckets = np.maximum(num_buckets, minlength)
    if maxlength is not None:
        num_buckets = np.minimum(num_buckets, maxlength)
    one_hots = one_hot(arr, num_buckets)
    # Reduce over every dimension except the last one.
    axes = tuple(range(0, one_hots.ndim - 1))
    if weights is not None:
        return np.sum(one_hots * weights[..., np.newaxis],
                      axis=axes).astype(dtype)
    return np.sum(one_hots, axis=axes).astype(dtype)
Example #4
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
    """Emulates tf.convert_to_tensor."""
    assert not tf.is_tensor(value), value
    if isinstance(value, np.ndarray):
        if dtype is not None:
            dtype = utils.numpy_dtype(dtype)
            # if np.result_type(value, dtype) != dtype:
            #   raise ValueError('Expected dtype {} but got {} with dtype {}.'.format(
            #       dtype, value, value.dtype))
            return value.astype(dtype)
        return value
    if isinstance(value, TensorShape):
        value = [int(d) for d in value.as_list()]
    if dtype is None and dtype_hint is not None:
        dtype_hint = utils.numpy_dtype(dtype_hint)
        value = np.array(value)
        if np.size(value):
            # Match TF behavior, which won't downcast e.g. float to int.
            if np.issubdtype(value.dtype, np.complexfloating):
                if not np.issubdtype(dtype_hint, np.complexfloating):
                    return value
            if np.issubdtype(value.dtype, np.floating):
                if not np.issubdtype(dtype_hint, np.floating):
                    return value
            if np.issubdtype(value.dtype, np.integer):
                if not np.issubdtype(dtype_hint, np.integer):
                    return value
        return value.astype(dtype_hint)
    return np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))
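
The hint branch above casts only when doing so would not downcast (float stays float under an integer hint), mirroring TF. A small isolated sketch of that rule, with illustrative names:

import numpy as np

def apply_dtype_hint(value, dtype_hint):
    # Only cast when it would not downcast (e.g. float -> int),
    # matching TF's convert_to_tensor behavior.
    value = np.array(value)
    if np.issubdtype(value.dtype, np.floating) and not np.issubdtype(
            dtype_hint, np.floating):
        return value  # keep float rather than downcast to int
    return value.astype(dtype_hint)

print(apply_dtype_hint([1.5, 2.5], np.int32).dtype)  # float64 (hint ignored)
print(apply_dtype_hint([1, 2], np.float32).dtype)    # float32 (hint applied)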
Example #5
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
    """Emulates tf.convert_to_tensor."""
    assert not tf.is_tensor(value), value
    if is_tensor(value):
        if dtype is not None:
            dtype = utils.numpy_dtype(dtype)
            # if np.result_type(value, dtype) != dtype:
            #   raise ValueError('Expected dtype {} but got {} with dtype {}.'.format(
            #       dtype, value, value.dtype))
            return value.astype(dtype)
        return value
    if isinstance(value, Dimension):
        value = _dimension_value(value)
    elif isinstance(value, TensorShape):
        value = value.as_list()
    # In JAX mode, onp.ndarray/onp.generic are not identified as Tensors.
    # By default, use the dtype of the values passed in.
    elif hasattr(value, 'dtype'):
        if dtype is not None:
            dtype = utils.numpy_dtype(dtype)
            return np.array(value).astype(dtype)
        return np.array(value)
    if dtype is None and dtype_hint is not None:
        dtype_hint = utils.numpy_dtype(dtype_hint)
        value = np.array(value)
        if np.size(value):
            # Match TF behavior, which won't downcast e.g. float to int.
            if np.issubdtype(value.dtype, np.complexfloating):
                if not np.issubdtype(dtype_hint, np.complexfloating):
                    return value
            if np.issubdtype(value.dtype, np.floating):
                if not (np.issubdtype(dtype_hint, np.floating)
                        or np.issubdtype(dtype_hint, np.complexfloating)):
                    return value
            if np.issubdtype(value.dtype, np.integer):
                if not (np.issubdtype(dtype_hint, np.integer)
                        or np.issubdtype(dtype_hint, np.floating)
                        or np.issubdtype(dtype_hint, np.complexfloating)):
                    return value
        return value.astype(dtype_hint)

    np_value = np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))
    if np.issubdtype(np_value.dtype, np.object_):
        raise ValueError('Numpy `object`s cannot be converted to `Tensor`s.')
    # We have no hints. In x64 mode, JAX (like NumPy) produces {int64, float64}
    # by default, which does not match TF's {int32, float32} defaults.
    if dtype is None and dtype_hint is None:
        # If the integer doesn't fit in int32, return an int64. This matches TF.
        if isinstance(value, int):
            if value > onp.iinfo(onp.int32).max or value < onp.iinfo(
                    onp.int32).min:
                return np.array(value, dtype=np.int64)
        if np.issubdtype(np_value.dtype, np.floating):
            return np_value.astype(np.float32)
        if np.issubdtype(np_value.dtype, np.integer):
            return np_value.astype(np.int32)
    return np_value
Example #6
def _range(start, limit=None, delta=1, dtype=None, name='range'):  # pylint: disable=unused-argument
    dtype = utils.numpy_dtype(dtype or utils.common_dtype([start], np.int32))
    start = ops.convert_to_tensor(start, dtype=dtype)
    limit = None if limit is None else ops.convert_to_tensor(limit,
                                                             dtype=dtype)
    delta = ops.convert_to_tensor(delta, dtype=dtype)
    return np.arange(start, limit, delta).astype(dtype)
Example #7
 def __init__(self,
              dtype,
              size=None,
              dynamic_size=None,
              clear_after_read=None,
              tensor_array_name=None,
              handle=None,
              flow=None,
              infer_shape=True,
              element_shape=None,
              colocate_with_first_write_call=True,
              data=None,
              name=None):
     self._dtype = utils.numpy_dtype(dtype)
     if data is None:
         if JAX_MODE and size is not None and element_shape is not None:
             data = np.empty((size, ) + element_shape, dtype=self._dtype)
         else:
             data = [None] * (0 if size is None else int(size))
     self._data = data
     self._size = size
     self._dynamic_size = dynamic_size
     self._clear_after_read = clear_after_read
     self._tensor_array_name = tensor_array_name
     self._handle = handle
     self._flow = flow
     self._infer_shape = infer_shape
     self._element_shape = element_shape
     self._colocate_with_first_write_call = colocate_with_first_write_call
     self._name = name
Example #8
def _eye(num_rows, num_columns=None, batch_shape=None,
         dtype=np.float32, name=None):  # pylint: disable=unused-argument
  dt = utils.numpy_dtype(dtype)
  x = np.eye(num_rows, num_columns).astype(dt)
  if batch_shape is not None:
    x = x * np.ones(tuple(batch_shape) + (1, 1)).astype(dt)
  return x
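
A quick sketch of the broadcasting trick `_eye` relies on: multiplying by ones with trailing singleton dims tiles the identity over the batch shape without an explicit loop.

import numpy as np

# ones with trailing (1, 1) dims broadcasts the identity across the batch.
batch_eye = np.eye(2, 3) * np.ones((4, 5, 1, 1))
print(batch_eye.shape)  # (4, 5, 2, 3)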
Example #9
def _one_hot(  # pylint: disable=unused-argument
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None):
  """One hot."""
  if on_value is None:
    on_value = 1
  if off_value is None:
    off_value = 0
  if dtype is None:
    dtype = utils.common_dtype([on_value, off_value], np.float32)
  else:
    dtype = utils.numpy_dtype(dtype)
  indices = np.array(indices)
  depth = np.array(depth)
  pred = abs(np.arange(depth, dtype=indices.dtype) -
             indices[..., np.newaxis]) > 0
  y_out = np.where(pred, np.array(off_value, dtype), np.array(on_value, dtype))
  if axis is not None:
    y_out = np.moveaxis(y_out, -1, axis)
  return y_out
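
The predicate marks every slot whose distance from the target index is nonzero, so `np.where` places `on_value` in exactly one slot. A quick demonstration in plain NumPy:

import numpy as np

indices = np.array([[0, 2], [1, 1]])
depth = 3
# True everywhere except the "hot" slot for each index.
pred = abs(np.arange(depth) - indices[..., np.newaxis]) > 0
y = np.where(pred, 0., 1.)
print(y.shape)  # (2, 2, 3): the depth axis is appended last (axis=None)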
Example #10
def _categorical(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
    rng = (np.random if seed is None
           else np.random.RandomState(seed & 0xffffffff))
    dtype = utils.numpy_dtype(dtype or np.int64)
    if not hasattr(logits, 'shape'):
        logits = np.array(logits, np.float32)
    n = logits.shape[-1]
    return rng.choice(n, p=_softmax(logits), size=num_samples).astype(dtype)
Example #11
 def __array__(self, dtype=None):
   if dtype is not None:
     dtype = utils.numpy_dtype(dtype)
     return self.__wrapped__.__array__(dtype)
   # Passing in dtype=None to __array__ has differing behavior in numpy.
   # When an `np.ndarray` has `.__array__(None)` invoked, the array is cast
   # to `float64`. Thus we handle this case separately.
   return self.__wrapped__.__array__()
Example #12
def _bincount(
        arr,
        weights=None,
        minlength=None,
        maxlength=None,  # pylint: disable=unused-argument
        dtype=tf.int32,
        name=None):  # pylint: disable=unused-argument
    return np.bincount(arr, weights,
                       minlength).astype(utils.numpy_dtype(dtype))
Example #13
def _histogram_fixed_width(values,
                           value_range,
                           nbins=100,
                           dtype=np.int32,
                           name=None):
    """Numpy implementation of `tf.histogram_fixed_width`."""
    del name
    return np.histogram(values, bins=nbins,
                        range=value_range)[0].astype(utils.numpy_dtype(dtype))
Example #14
def _categorical_jax(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
  dtype = utils.numpy_dtype(dtype or np.int64)
  if not hasattr(logits, 'shape') or not hasattr(logits, 'dtype'):
    logits = np.array(logits, np.float32)
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  if seed is None:
    raise ValueError('Must provide PRNGKey to sample in JAX.')
  z = jaxrand.gumbel(
      key=seed, shape=logits.shape + (num_samples,), dtype=logits.dtype)
  return np.argmax(np.expand_dims(logits, -1) + z, axis=-2).astype(dtype)
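
`_categorical_jax` uses the Gumbel-max trick: `argmax(logits + Gumbel noise)` is distributed as `Categorical(softmax(logits))`, which needs only fixed shapes and a stateless key. A plain-NumPy sanity check of the trick (illustrative, not from the source):

import numpy as np

rng = np.random.default_rng(0)
logits = np.log([0.1, 0.3, 0.6])
# Gumbel-max trick: argmax over logits + iid Gumbel noise samples the
# categorical distribution softmax(logits).
samples = np.argmax(logits + rng.gumbel(size=(10000, 3)), axis=-1)
print(np.bincount(samples, minlength=3) / 10000.)  # ~[0.1, 0.3, 0.6]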
Example #15
def _binomial(shape, seed, counts, probs, output_dtype=np.int32, name=None):  # pylint: disable=unused-argument
  rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
  invalid_count = (np.int64(counts) < 0) != (counts < 0)
  if np.any(invalid_count):
    raise ValueError('int64 overflow: {} -> {}'.format(
        counts[np.where(invalid_count)],
        np.int64(counts)[np.where(invalid_count)]))
  probs = np.where(counts > 0, probs, 0)
  samps = rng.binomial(np.int64(counts), np.float64(probs), shape)
  return samps.astype(utils.numpy_dtype(output_dtype))
Example #16
def _eye(num_rows,
         num_columns=None,
         batch_shape=None,
         dtype=tf.float32,
         name=None):  # pylint: disable=unused-argument
    dt = utils.numpy_dtype(dtype)
    x = np.eye(num_rows, num_columns).astype(dt)
    if batch_shape is not None:
        x = x * np.ones(np.concatenate([batch_shape, [1, 1]],
                                       axis=0)).astype(dt)
    return x
Example #17
def _histogram_fixed_width_bins(values, value_range, nbins=100, dtype=np.int32,
                                name=None):
  """Numpy implementation of `tf.histogram_fixed_width_bins`."""
  del name
  nbins_float = np.array(nbins, values.dtype)
  scaled_values = truediv(
      values - value_range[0], value_range[1] - value_range[0])
  indices = floor(nbins_float * scaled_values)
  indices = clip_by_value(indices, 0, nbins_float - 1).astype(
      utils.numpy_dtype(dtype))
  return indices
Example #18
def _unique(x, out_idx=tf.int32, name=None):  # pylint: disable=unused-argument
  """Numpy implementation of `tf.unique`."""
  x = np.array(x)
  if len(x.shape) != 1:
    raise tf.errors.InvalidArgumentError('unique expects a 1D vector.')
  y, idx = np.unique(x,
                     return_index=True,
                     return_inverse=False,
                     return_counts=False,
                     axis=None)
  idx = idx.astype(utils.numpy_dtype(out_idx))
  return _UniqueOutput(y=y, idx=idx)
Example #19
def _confusion_matrix(
    labels, predictions, num_classes=None, weights=None,
    dtype=np.int32, name=None):
  """Return confusion matrix between predictions and labels."""
  del name
  if num_classes is None:
    num_classes = np.maximum(np.max(predictions), np.max(labels)) + 1
  cmatrix = np.zeros([num_classes, num_classes], dtype=utils.numpy_dtype(dtype))
  if weights is None:
    weights = 1
  if not JAX_MODE:
    np.add.at(cmatrix, [labels, predictions], weights)
    return cmatrix
  return jax.ops.index_add(cmatrix, [labels, predictions], weights)
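
Note that `jax.ops.index_add` has since been removed from JAX; in current releases the equivalent functional scatter-add is the `.at[...]` syntax. A minimal sketch of the modern spelling:

import jax.numpy as jnp

labels = jnp.array([0, 1, 1])
predictions = jnp.array([0, 1, 0])
cmatrix = jnp.zeros((2, 2), dtype=jnp.int32)
# Functional scatter-add: returns a new array with +1 at each (row, col).
cmatrix = cmatrix.at[labels, predictions].add(1)
print(cmatrix)  # [[1 0]
                #  [1 1]]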
Example #20
def _range(start, limit=None, delta=1, dtype=None, name='range'):  # pylint: disable=unused-argument
  """Emulates tf.range."""
  # Emulating dtype inference logic from tf.range
  dtype = utils.numpy_dtype(dtype)
  start = ops.convert_to_tensor(start, dtype=dtype)
  limit = None if limit is None else ops.convert_to_tensor(limit, dtype=dtype)
  delta = ops.convert_to_tensor(delta, dtype=dtype)
  if dtype is None:
    dtype_hierarchy = [np.int32, np.int64, np.float32, np.float64]
    inferred_dtype = max([arg.dtype for arg in [start, limit, delta]
                          if arg is not None],
                         key=dtype_hierarchy.index)
  else:
    inferred_dtype = dtype
  return np.arange(start, limit, delta).astype(inferred_dtype)
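
The hierarchy lookup implements tf.range's promotion rule: the argument whose dtype sits furthest along `[int32, int64, float32, float64]` wins. In isolation:

import numpy as np

dtype_hierarchy = [np.int32, np.int64, np.float32, np.float64]
args = [np.array(0, np.int32), np.array(10., np.float32)]
# np.dtype('int32') == np.int32, so list.index resolves dtypes directly.
inferred = max((a.dtype for a in args), key=dtype_hierarchy.index)
print(inferred)  # float32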
Example #21
 def __init__(self,
              initial_value=None,
              trainable=True,
              validate_shape=True,
              caching_device=None,
              name=None,
              variable_def=None,
              dtype=None,
              import_scope=None,
              constraint=None,
              shape=None):
     assert constraint is None
     v = convert_to_tensor(initial_value)
     if dtype is not None:
         v = v.astype(utils.numpy_dtype(dtype))
     super(NumpyVariable, self).__init__(v)
     self.initializer = None
Example #22
def _range(start, limit=None, delta=1, dtype=None, name='range'):  # pylint: disable=unused-argument
  """Emulates tf.range."""
  # Emulating dtype inference logic from tf.range
  dtype = utils.numpy_dtype(dtype)
  infer_dtype = lambda t: ops.convert_to_tensor(t, dtype=dtype).dtype
  # We must keep start, limit, and delta static np.array since they determine
  # the size of the result array, which JAX requires to be static.
  start = onp.array(start, dtype=infer_dtype(start))
  limit = None if limit is None else onp.array(limit, dtype=infer_dtype(limit))
  delta = onp.array(delta, dtype=infer_dtype(delta))
  if dtype is None:
    dtype_hierarchy = [np.int32, np.int64, np.float32, np.float64]
    inferred_dtype = max([arg.dtype for arg in [start, limit, delta]
                          if arg is not None],
                         key=dtype_hierarchy.index)
  else:
    inferred_dtype = dtype
  return np.arange(start, limit, delta).astype(inferred_dtype)
Example #23
def _lu(input, output_idx_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns Lu(lu, p), as TF does."""
  del name
  if JAX_MODE:  # But JAX uses XLA, which can do a batched factorization.
    lu_out, pivots = scipy_linalg.lu_factor(input)
    from jax import lax_linalg  # pylint: disable=g-import-not-at-top
    return Lu(lu_out,
              lax_linalg.lu_pivots_to_permutation(pivots, lu_out.shape[-1]))
  # Scipy can't batch, so we must do so manually.
  nbatch = int(np.prod(input.shape[:-2]))
  dim = input.shape[-1]
  flat_mat = input.reshape(nbatch, dim, dim)
  flat_lu = np.empty((nbatch, dim, dim), dtype=input.dtype)
  flat_piv = np.empty((nbatch, dim), dtype=utils.numpy_dtype(output_idx_type))
  if np.size(flat_lu):  # Avoid non-empty batches of empty matrices.
    for i, mat in enumerate(flat_mat):
      lu_out, pivots = scipy_linalg.lu_factor(mat)
      flat_lu[i] = lu_out
      flat_piv[i] = _lu_pivot_to_permutation(pivots, flat_lu.shape[-1])
  return Lu(flat_lu.reshape(*input.shape), flat_piv.reshape(*input.shape[:-1]))
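
Because `scipy.linalg.lu_factor` is unbatched, the fallback flattens the batch dimensions and loops, as above. A runnable miniature of that loop (the example matrices are arbitrary):

import numpy as np
from scipy import linalg as scipy_linalg

mats = np.stack([np.array([[4., 3.], [6., 3.]]), np.eye(2)])
lus = np.empty_like(mats)
pivs = np.empty(mats.shape[:-1], dtype=np.int32)
# lu_factor handles one matrix at a time; batch by iterating.
for i, mat in enumerate(mats):
    lus[i], pivs[i] = scipy_linalg.lu_factor(mat)
print(lus.shape, pivs.shape)  # (2, 2, 2) (2, 2)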
Example #24
 def __init__(self,
              dtype,
              size=None,
              dynamic_size=None,
              clear_after_read=None,
              tensor_array_name=None,
              handle=None,
              flow=None,
              infer_shape=True,
              element_shape=None,
              colocate_with_first_write_call=True,
              data=None,
              name=None):
     self._dtype = utils.numpy_dtype(dtype)
     if data is None:
         if JAX_MODE and size is not None and element_shape is not None:
             data = np.empty((size, ) + tuple(element_shape),
                             dtype=self._dtype)
         # Can be useful for finding failure cases in JAX TensorArray-using code.
         # elif JAX_MODE:
         #   raise ValueError(
         #       'Missing shape argument: size {} element_shape {}'.format(
         #           size, element_shape))
         else:
             data = [None] * (0 if size is None else int(size))
     self._data = data
     self._size = size
     self._dynamic_size = dynamic_size
     self._clear_after_read = clear_after_read
     self._tensor_array_name = tensor_array_name
     self._handle = handle
     self._flow = flow
     self._infer_shape = infer_shape
     self._element_shape = element_shape
     self._colocate_with_first_write_call = colocate_with_first_write_call
     self._name = name
Example #25
 def __init__(self,
              dtype,
              size=None,
              dynamic_size=None,
              clear_after_read=None,
              tensor_array_name=None,
              handle=None,
              flow=None,
              infer_shape=True,
              element_shape=None,
              colocate_with_first_write_call=True,
              name=None):
     self._data = [None] * (size if size else 0)
     self._dtype = utils.numpy_dtype(dtype)
     self._size = size
     self._dynamic_size = dynamic_size
     self._clear_after_read = clear_after_read
     self._tensor_array_name = tensor_array_name
     self._handle = handle
     self._flow = flow
     self._infer_shape = infer_shape
     self._element_shape = element_shape
     self._colocate_with_first_write_call = colocate_with_first_write_call
     self._name = name
Example #26
def _constant(value, dtype=None, shape=None, name='Const'):  # pylint: disable=unused-argument
    x = np.array(value,
                 dtype=None if dtype is None else utils.numpy_dtype(dtype))
    if shape is None:
        return x
    return np.reshape(x, shape)
Example #27

broadcast_dynamic_shape = utils.copy_docstring(tf.broadcast_dynamic_shape,
                                               _broadcast_static_shape)

broadcast_static_shape = utils.copy_docstring(tf.broadcast_static_shape,
                                              _broadcast_static_shape)

broadcast_to = utils.copy_docstring(
    tf.broadcast_to,
    lambda input, shape, name=None: np.broadcast_to(input, shape))

cast = utils.copy_docstring(
    tf.cast,
    lambda x, dtype, name=None: np.array(x).astype(utils.numpy_dtype(dtype)))

clip_by_value = utils.copy_docstring(
    tf.clip_by_value,
    lambda t, clip_value_min, clip_value_max, name=None:  # pylint: disable=g-long-lambda
    np.clip(t, clip_value_min, clip_value_max))

constant = utils.copy_docstring(tf.constant, _constant)

control_dependencies = utils.copy_docstring(tf.control_dependencies,
                                            _control_dependencies)

convert_to_tensor = utils.copy_docstring(tf.convert_to_tensor,
                                         _convert_to_tensor)

custom_gradient = utils.copy_docstring(tf.custom_gradient, lambda f: f)
Example #28
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin,unused-argument
  return np.zeros_like(input, dtype=utils.numpy_dtype(dtype))
Example #29
meshgrid = utils.copy_docstring(
    'tf.meshgrid',
    np.meshgrid)

norm = utils.copy_docstring(
    'tf.norm',
    norm)

one_hot = utils.copy_docstring(
    'tf.one_hot',
    _one_hot)

ones = utils.copy_docstring(
    'tf.ones',
    lambda shape, dtype=np.float32, name=None: np.ones(  # pylint: disable=g-long-lambda
        shape, utils.numpy_dtype(dtype)))

ones_like = utils.copy_docstring(
    'tf.ones_like',
    _ones_like)

pad = utils.copy_docstring(
    'tf.pad',
    _pad)

range = utils.copy_docstring(  # pylint: disable=redefined-builtin
    'tf.range',
    _range)

rank = utils.copy_docstring(
    'tf.rank',
Example #30
broadcast_to = utils.copy_docstring(
    'tf.broadcast_to',
    lambda input, shape, name=None: np.broadcast_to(input, shape))


def _cast(x, dtype):
  x = np.asarray(x)
  if (np.issubdtype(x.dtype, np.complexfloating) and
      not np.issubdtype(dtype, np.complexfloating)):
    x = np.real(x)
  return x.astype(dtype)


cast = utils.copy_docstring(
    'tf.cast',
    lambda x, dtype, name=None: _cast(x, utils.numpy_dtype(dtype)))
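
A note on the complex branch: NumPy's `astype` emits a `ComplexWarning` when narrowing complex to real, so `_cast` takes `np.real` first, matching `tf.cast`'s behavior of discarding the imaginary part. For instance:

import numpy as np

z = np.array([1 + 2j, 3 - 4j])
# np.real first avoids the ComplexWarning a bare astype would emit.
print(np.real(z).astype(np.float32))  # [1. 3.]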

clip_by_value = utils.copy_docstring(
    'tf.clip_by_value',
    lambda t, clip_value_min, clip_value_max, name=None:  # pylint: disable=g-long-lambda
    np.clip(t, clip_value_min, clip_value_max))

constant = utils.copy_docstring(
    'tf.constant',
    _constant)

control_dependencies = utils.copy_docstring(
    'tf.control_dependencies',
    _control_dependencies)

convert_to_tensor = utils.copy_docstring(