示例#1
0
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
    """Emulates tf.convert_to_tensor.

    Args:
      value: Array-like to convert. Must not already be a TF tensor.
      dtype: Optional required dtype; the result is cast to it.
      dtype_hint: Optional preferred dtype, consulted only when `dtype` is None.
        Mirroring TF, the hint is ignored when honoring it would change the
        value's kind (complex / floating / integer), e.g. float data is not
        cast to an integer hint.
      name: Ignored.

    Returns:
      An ndarray.
    """
    assert not tf.is_tensor(value), value
    if isinstance(value, np.ndarray):
        if dtype is None:
            return value
        # NOTE: no compatibility check between value.dtype and dtype is done;
        # the array is cast unconditionally.
        return value.astype(utils.numpy_dtype(dtype))
    if isinstance(value, TensorShape):
        value = [int(d) for d in value.as_list()]
    if dtype is None and dtype_hint is not None:
        dtype_hint = utils.numpy_dtype(dtype_hint)
        value = np.array(value)
        if np.size(value):
            # Match TF behavior, which won't downcast e.g. float to int.
            if np.issubdtype(value.dtype, np.complexfloating):
                if not np.issubdtype(dtype_hint, np.complexfloating):
                    return value
            if np.issubdtype(value.dtype, np.floating):
                if not np.issubdtype(dtype_hint, np.floating):
                    return value
            if np.issubdtype(value.dtype, np.integer):
                if not np.issubdtype(dtype_hint, np.integer):
                    return value
        return value.astype(dtype_hint)
    # Compare against None rather than relying on truthiness: a NumPy dtype
    # instance can be falsy (its len() counts structured fields, 0 for plain
    # dtypes), so `dtype or dtype_hint` could silently drop an explicit dtype.
    chosen = dtype if dtype is not None else dtype_hint
    return np.array(value, dtype=utils.numpy_dtype(chosen))
示例#2
0
def _bincount(
        arr,
        weights=None,
        minlength=None,
        maxlength=None,
        dtype=tf.int32,
        name=None):  # pylint: disable=unused-argument
    """Emulates tf.math.bincount: counts occurrences of each value in `arr`.

    Args:
      arr: Array-like of non-negative integers. It is flattened, so every entry
        is counted regardless of rank.
      weights: Optional array-like with the same number of elements as `arr`;
        when given, each occurrence adds its weight instead of 1.
      minlength: If given, the output has at least this many bins.
      maxlength: If given, values `>= maxlength` are dropped and the output has
        at most this many bins (this bound wins over `minlength`).
      dtype: Output dtype when `weights` is None. When `weights` is given the
        result keeps the weights' dtype, as tf.math.bincount documents.
      name: Ignored.

    Returns:
      A 1-D array of per-value counts (or summed weights).
    """
    arr = np.asarray(arr).reshape(-1)
    if arr.size == 0:
        # np.asarray([]) is float64, which np.bincount refuses to cast.
        arr = arr.astype(np.int32)
    if weights is not None:
        weights = np.asarray(weights).reshape(-1)
    if maxlength is not None:
        keep = arr < maxlength
        arr = arr[keep]
        if weights is not None:
            weights = weights[keep]
    # NumPy deprecated (and later rejected) minlength=None; 0 means no minimum.
    counts = np.bincount(arr, weights, 0 if minlength is None else minlength)
    if maxlength is not None:
        counts = counts[:maxlength]
    if weights is not None:
        return counts.astype(weights.dtype)
    return counts.astype(utils.numpy_dtype(dtype))
示例#3
0
def _eye(num_rows,
         num_columns=None,
         batch_shape=None,
         dtype=tf.float32,
         name=None):  # pylint: disable=unused-argument
    """Emulates tf.eye: an identity matrix, optionally replicated over a batch.

    Args:
      num_rows: Number of rows of each matrix.
      num_columns: Number of columns; np.eye uses num_rows when None.
      batch_shape: Optional leading batch dimensions.
      dtype: Output dtype (TF or NumPy), mapped via utils.numpy_dtype.
      name: Ignored.

    Returns:
      Array of shape `batch_shape + (num_rows, num_columns)`.
    """
    np_dtype = utils.numpy_dtype(dtype)
    identity = np.eye(num_rows, num_columns).astype(np_dtype)
    if batch_shape is None:
        return identity
    # Multiplying by ones of shape (*batch_shape, 1, 1) broadcasts the matrix
    # across every batch index.
    batch_ones = np.ones((*batch_shape, 1, 1)).astype(np_dtype)
    return identity * batch_ones
示例#4
0
def _categorical(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
    """Emulates tf.random.categorical with NumPy's random generator.

    Args:
      logits: 2-D array-like `[batch_size, num_classes]` of unnormalized
        log-probabilities (samples are drawn along axis 1).
      num_samples: Number of draws per batch member.
      dtype: Integer output dtype; defaults to int64.
      seed: Optional int seed; only its low 32 bits are used. When None, the
        global `np.random` state is used.
      name: Ignored.

    Returns:
      Array of shape `[batch_size, num_samples]` holding class indices.
    """
    rng = np.random if seed is None else np.random.RandomState(seed
                                                               & 0xffffffff)
    # Compare against None: a NumPy dtype instance can be falsy, so
    # `dtype or np.int64` could replace an explicit dtype with int64.
    dtype = utils.numpy_dtype(np.int64 if dtype is None else dtype)
    if not hasattr(logits, 'shape'):
        logits = np.array(logits, np.float32)
    probs = _softmax(logits)
    n = logits.shape[-1]
    samples = np.apply_along_axis(
        lambda p: rng.choice(n, p=p, size=num_samples), 1, probs)
    # rng.choice yields the platform default int; honor the requested dtype.
    return samples.astype(dtype)
示例#5
0
def _categorical_jax(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
    """Emulates tf.random.categorical in JAX.

    Args:
      logits: Array-like `[..., num_classes]` of unnormalized log-probabilities.
      num_samples: Number of draws per leading index.
      dtype: Integer output dtype; defaults to int64.
      seed: A JAX PRNGKey (required).
      name: Ignored.

    Returns:
      Array of shape `logits.shape[:-1] + (num_samples,)` with class indices.

    Raises:
      ValueError: If `seed` is None.
    """
    # Fail fast: nothing can be sampled in JAX without an explicit key.
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    # Compare against None: a NumPy dtype instance can be falsy, so
    # `dtype or np.int64` could replace an explicit dtype with int64.
    dtype = utils.numpy_dtype(np.int64 if dtype is None else dtype)
    if not hasattr(logits, 'shape') or not hasattr(logits, 'dtype'):
        logits = np.array(logits, np.float32)
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    # Gumbel-max trick: argmax over classes of logits plus i.i.d. Gumbel noise;
    # the trailing noise axis provides num_samples independent draws.
    z = jaxrand.gumbel(key=seed,
                       shape=logits.shape + (num_samples, ),
                       dtype=logits.dtype)
    return np.argmax(np.expand_dims(logits, -1) + z, axis=-2).astype(dtype)
示例#6
0
def _lu(input, output_idx_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
    """Returns Lu(lu, p), as TF does.

    Args:
      input: Array of shape `[..., dim, dim]` to factorize.
      output_idx_type: Dtype of the permutation on the NumPy/SciPy path.
      name: Ignored.

    Returns:
      `Lu(lu, p)`: `lu` has the shape of `input` and holds the combined factors
      from `scipy_linalg.lu_factor`; `p` has shape `input.shape[:-1]` and is the
      row permutation derived from the pivots.
    """
    del name
    if JAX_MODE:  # But JAX uses XLA, which can do a batched factorization.
        lu_out, pivots = scipy_linalg.lu_factor(input)
        try:
            # Current home of JAX's linalg primitives.
            from jax.lax import linalg as lax_linalg  # pylint: disable=g-import-not-at-top
        except ImportError:
            # Older JAX releases exposed the same module as `jax.lax_linalg`.
            from jax import lax_linalg  # pylint: disable=g-import-not-at-top
        return Lu(
            lu_out,
            lax_linalg.lu_pivots_to_permutation(pivots, lu_out.shape[-1]))
    # Scipy can't batch, so we must do so manually.
    nbatch = int(np.prod(input.shape[:-2]))
    dim = input.shape[-1]
    flat_mat = input.reshape(nbatch, dim, dim)
    flat_lu = np.empty((nbatch, dim, dim), dtype=input.dtype)
    flat_piv = np.empty((nbatch, dim),
                        dtype=utils.numpy_dtype(output_idx_type))
    if np.size(flat_lu):  # Avoid non-empty batches of empty matrices.
        for i, mat in enumerate(flat_mat):
            lu_out, pivots = scipy_linalg.lu_factor(mat)
            flat_lu[i] = lu_out
            flat_piv[i] = _lu_pivot_to_permutation(pivots, flat_lu.shape[-1])
    return Lu(flat_lu.reshape(*input.shape),
              flat_piv.reshape(*input.shape[:-1]))
示例#7
0
def _constant(value, dtype=None, shape=None, name='Const'):  # pylint: disable=unused-argument
    """Emulates tf.constant.

    Args:
      value: Scalar or array-like holding the constant's contents.
      dtype: Optional TF or NumPy dtype; when None NumPy infers it from value.
      shape: Optional target shape. A value with a matching element count is
        reshaped; a single-element value is broadcast to fill the shape, as
        tf.constant does (e.g. `constant(0., shape=[2, 3])`).
      name: Ignored.

    Returns:
      An ndarray.
    """
    x = np.array(value,
                 dtype=None if dtype is None else utils.numpy_dtype(dtype))
    if shape is None:
        return x
    num_elements = int(np.prod(shape))
    # np.reshape cannot expand one element into many; fill instead. Negative
    # products (from -1 dims) are left to np.reshape, as before.
    if np.size(x) == 1 and num_elements >= 0 and num_elements != 1:
        return np.full(shape, x.reshape(()), dtype=x.dtype)
    return np.reshape(x, shape)
示例#8
0
                                               _broadcast_static_shape)

# Public TF-compatible aliases. Each pairs an implementation with the TF API
# named first; utils.copy_docstring presumably copies that API's docstring onto
# the implementation (helper not shown here). The lambdas below accept `name`
# only for signature compatibility and ignore it.
broadcast_static_shape = utils.copy_docstring(tf.broadcast_static_shape,
                                              _broadcast_static_shape)

# Same helper with as_tensorshape=True bound via functools.partial; presumably
# returns the result as a TensorShape — confirm in _broadcast_static_shape.
broadcast_static_shape_as_tensorshape = utils.copy_docstring(
    tf.broadcast_static_shape,
    functools.partial(_broadcast_static_shape, as_tensorshape=True))

# NOTE(review): calls `onp.broadcast_to` rather than `np.broadcast_to` like its
# neighbours; confirm whether that difference is intentional.
broadcast_to = utils.copy_docstring(
    tf.broadcast_to,
    lambda input, shape, name=None: onp.broadcast_to(input, shape))

# Converts x to an array, then casts it to the NumPy equivalent of `dtype`.
cast = utils.copy_docstring(
    tf.cast,
    lambda x, dtype, name=None: np.array(x).astype(utils.numpy_dtype(dtype)))

# Elementwise clip of t into [clip_value_min, clip_value_max] via np.clip.
clip_by_value = utils.copy_docstring(
    tf.clip_by_value,
    lambda t, clip_value_min, clip_value_max, name=None:  # pylint: disable=g-long-lambda
    np.clip(t, clip_value_min, clip_value_max))

# Wrappers around private implementations defined elsewhere in this module.
constant = utils.copy_docstring(tf.constant, _constant)

control_dependencies = utils.copy_docstring(tf.control_dependencies,
                                            _control_dependencies)

convert_to_tensor = utils.copy_docstring(tf.convert_to_tensor,
                                         _convert_to_tensor)

# Identity decorator: `f` is returned unchanged, so no gradient override takes
# effect in this backend.
custom_gradient = utils.copy_docstring(tf.custom_gradient, lambda f: f)
示例#9
0

def _top_k(input, k=1, sorted=True, name=None):  # pylint: disable=unused-argument,redefined-builtin
    """Emulates tf.math.top_k: the `k` largest entries along the last axis.

    Args:
      input: Array-like with rank >= 1.
      k: Number of entries to select; must satisfy `0 <= k <= input.shape[-1]`.
      sorted: Ignored. Results are always in descending order, which is also a
        valid ordering when `sorted=False`.
      name: Ignored.

    Returns:
      A `TopKV2(values, indices)` namedtuple. `values` keeps the dtype of
      `input` and has shape `input.shape[:-1] + (k,)`; `indices` are int32
      positions along the last axis. Equal values are ordered by increasing
      index.

    Raises:
      ValueError: If `input` is a scalar or `k` is out of range.
    """
    import collections  # pylint: disable=g-import-not-at-top
    top_k_result = collections.namedtuple('TopKV2', ['values', 'indices'])
    x = np.asarray(input)
    if x.ndim < 1:
        raise ValueError('input must have rank at least 1, got a scalar.')
    n = x.shape[-1]
    k = int(k)
    if not 0 <= k <= n:
        raise ValueError('k must be in [0, {}], got {}.'.format(n, k))
    # A stable ascending argsort of the reversed last axis, read backwards,
    # gives descending order with ties at increasing original index. This
    # avoids negating the values, which is wrong for unsigned ints.
    rev_order = np.argsort(x[..., ::-1], axis=-1, kind='stable')[..., ::-1]
    indices = (n - 1 - rev_order)[..., :k]
    values = np.take_along_axis(x, indices, axis=-1)
    return top_k_result(values, indices.astype(np.int32))


# --- Begin Public Functions --------------------------------------------------

abs = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.math.abs,
    lambda x, name=None: np.abs(x))


def _accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):  # pylint: disable=unused-argument
    """Emulates tf.math.accumulate_n: elementwise sum of a list of arrays.

    Args:
      inputs: Non-empty iterable of array-likes to add together.
      shape: Ignored.
      tensor_dtype: Optional dtype for the result. When None, the result keeps
        the dtype produced by summing the inputs.
      name: Ignored.

    Returns:
      The elementwise sum.

    Raises:
      ValueError: If `inputs` is empty.
    """
    arrays = [np.array(x) for x in inputs]
    if not arrays:
        raise ValueError('accumulate_n requires at least one input.')
    total = sum(arrays)
    if tensor_dtype is None:
        # Skip the cast: `astype(None)` would silently convert to float64.
        return total
    return total.astype(utils.numpy_dtype(tensor_dtype))


accumulate_n = utils.copy_docstring(tf.math.accumulate_n, _accumulate_n)

# Elementwise inverse cosine via np.arccos; `name` is accepted and ignored.
acos = utils.copy_docstring(tf.math.acos, lambda x, name=None: np.arccos(x))

# Elementwise inverse hyperbolic cosine via np.arccosh.
acosh = utils.copy_docstring(tf.math.acosh, lambda x, name=None: np.arccosh(x))

# Elementwise (broadcasting) addition via np.add.
add = utils.copy_docstring(tf.math.add, lambda x, y, name=None: np.add(x, y))

# Sums the inputs with `sum` after converting each with np.array.
# NOTE(review): an empty `inputs` yields the int 0 here, whereas TF rejects an
# empty list — confirm no caller relies on either behavior.
add_n = utils.copy_docstring(
    tf.math.add_n, lambda inputs, name=None: sum(map(np.array, inputs)))

# Elementwise argument (phase angle) via np.angle.
angle = utils.copy_docstring(tf.math.angle,
                             lambda input, name=None: np.angle(input))

argmax = utils.copy_docstring(
    tf.math.argmax,
示例#10
0
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
    """Emulates tf.zeros_like: zeros with the shape of `input`.

    Args:
      input: Array-like whose shape (and, by default, dtype) is copied.
      dtype: Optional output dtype; defaults to `input.dtype`.
      name: Forwarded to `tf.zeros` on the non-NumPy path only.

    Returns:
      An array of zeros shaped like `input`.
    """
    s = _shape(input)
    # Compare against None: a NumPy dtype instance can be falsy, so
    # `dtype or input.dtype` could silently replace an explicit dtype.
    out_dtype = input.dtype if dtype is None else dtype
    if isinstance(s, (np.ndarray, onp.generic)):
        return np.zeros(s, utils.numpy_dtype(out_dtype))
    # NOTE(review): reached only when `_shape` returns a non-NumPy value. The
    # default dtype is the input's; `s.dtype` would be the shape's index dtype,
    # not the dtype zeros_like should produce.
    return tf.zeros(s, out_dtype, name)
示例#11
0
def _size(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin, unused-argument
    """Emulates tf.size: the total element count of `input`, as `out_type`."""
    dims = np.array(input).shape
    # Product over the shape tuple; an empty shape (scalar input) gives 1.
    num_elements = np.prod(dims)
    target_dtype = utils.numpy_dtype(out_type)
    return num_elements.astype(target_dtype)
示例#12
0
# NOTE(review): the result is cast back to the dtype of `start`, so an integer
# `start` truncates the evenly spaced samples; confirm callers pass floats.
linspace = utils.copy_docstring(
    tf.linspace,
    lambda start, stop, num, name=None: (  # pylint: disable=g-long-lambda
        np.linspace(start, stop, num).astype(np.array(start).dtype)))

# np.meshgrid is used directly, so TF's `name` keyword is not accepted here.
meshgrid = utils.copy_docstring(tf.meshgrid, np.meshgrid)

# Rebinds `norm` (bound earlier in this module, not shown here) to the result
# of wrapping it with utils.copy_docstring.
norm = utils.copy_docstring(tf.norm, norm)

# Wrapper around a private implementation defined elsewhere in this module.
one_hot = utils.copy_docstring(tf.one_hot, _one_hot)

# Array of ones; dtype defaults to tf.float32 and is mapped to its NumPy form.
ones = utils.copy_docstring(
    tf.ones,
    lambda shape, dtype=tf.float32, name=None: np.ones(  # pylint: disable=g-long-lambda
        shape, utils.numpy_dtype(dtype)))

# Wrappers around private implementations defined elsewhere in this module.
ones_like = utils.copy_docstring(tf.ones_like, _ones_like)

pad = utils.copy_docstring(tf.pad, _pad)

# With limit=None, np.arange treats `start` as the stop value, so range(n)
# yields 0..n-1 as in TF. With dtype=None NumPy presumably infers the dtype
# (assuming utils.numpy_dtype passes None through).
range = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.range,
    lambda start, limit=None, delta=1, dtype=None, name='range': (  # pylint: disable=g-long-lambda
        np.arange(start, limit, delta, utils.numpy_dtype(dtype))))

# Number of dimensions as a plain Python int (ndarray.ndim).
rank = utils.copy_docstring(tf.rank,
                            lambda input, name=None: np.array(input).ndim)  # pylint: disable=redefined-builtin,g-long-lambda

# Direct np.reshape; `name` is ignored.
reshape = utils.copy_docstring(
    tf.reshape, lambda tensor, shape, name=None: np.reshape(tensor, shape))