Example #1
def testBooleanIndexingDynamicShape(self):
    x = onp.zeros(3)
    i = onp.array([True, True, False])
    ans = x[i]
    expected = jnp.asarray(x)[i]
    self.assertAllClose(ans, expected, check_dtypes=True)
Example #2
def f(a, b):
    y = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
    if has_aux:
        return y, tf_np.asarray(1)
    else:
        return y
Example #3
def testBooleanIndexingList1D(self):
    idx = [True, True, False]
    x = jnp.asarray(onp.arange(3))
    ans = x[idx]
    expected = onp.arange(3)[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #4
def testBooleanIndexingList2DBroadcast(self):
    idx = [True, True, False, True]
    x = onp.arange(8).reshape(4, 2)
    ans = jnp.asarray(x)[idx]
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)
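Examples #1, #3, and #4 all exercise NumPy-style boolean-mask indexing through `jnp.asarray`: a 1-D list of booleans selects entries (or, for a 2-D array, whole rows) along the leading axis. A standalone NumPy illustration of the semantics being tested:

import numpy as onp

x = onp.arange(8).reshape(4, 2)
idx = [True, True, False, True]
print(x[idx])
# [[0 1]
#  [2 3]
#  [6 7]]  (rows 0, 1, and 3 are kept; row 2 is dropped)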
Example #5
def uniform(rng, shape, dtype):
  if np.issubdtype(dtype, np.integer):
    minval = None
  else:
    minval = 0
  return tf_np.asarray(rng.uniform(shape=shape, dtype=dtype, minval=minval))
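A sketch of how this helper might be called, assuming `rng` is a `tf.random.Generator` (whose `uniform` method draws from [0, 1) for float dtypes and, when both `minval` and `maxval` are left as None, from the full range of an integer dtype):

import numpy as np
import tensorflow as tf

rng = tf.random.Generator.from_seed(42)
floats = uniform(rng, shape=(2, 3), dtype=np.float32)  # values in [0, 1)
ints = uniform(rng, shape=(2, 3), dtype=np.int32)      # full int32 range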
Example #6
def testPrng(self):
  self.assertAllEqual(tf_np.asarray(123, np.int64), extensions.prng(123))
Example #7
def gather(a):
    return tf_np.asarray(tf.gather(a.data, idxs, batch_dims=rank - 1))
Example #8
def loss_fn(params, inputs, targets):
  predicted = params[0] * inputs + params[1]
  loss = tf.reduce_mean(input_tensor=tf.square(predicted - targets))
  return tf_np.asarray(loss)
Example #9
def expit(x):
    """Compute 1 / (1 + exp(-x))."""
    return tf_np.asarray(tf.math.sigmoid(x.data))
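Because `expit` simply unwraps `x.data` and applies `tf.math.sigmoid`, a quick numeric check looks like this (a sketch, assuming `tf_np.asarray` returns the ndarray wrapper whose `.data` tensor `expit` expects):

import numpy as np

y = expit(tf_np.asarray(np.array([-1.0, 0.0, 1.0])))
print(y)  # ~[0.2689, 0.5, 0.7311]; the middle value is exactly 1 / (1 + exp(0)) == 0.5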
Example #10
def erf(x):
    """Computes the Gauss error function of x element-wise."""
    return tf_np.asarray(tf.math.erf(x.data))
Example #11
def scan(f, init, xs, length=None, reverse=False):
    """Scan a function over leading array axes while carrying along state.

  See the docstring of `jax.lax.scan`
  (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) for
  details.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output. Note that
      the input and output carry must have the same dtype.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value. This value must have the same structure as
      the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
    init, xs = tf.nest.map_structure(
        lambda x: tf_np.asarray(x) if x is not None else None, (init, xs))
    init, xs = _np_to_tf((init, xs))

    def get_length(x):
        if x is None:
            return None
        if x.shape.rank == 0:
            raise ValueError(
                "Some array in `xs` doesn't have a leading dimension")
        return x.shape[0]

    lengths = tf.nest.flatten(tf.nest.map_structure(get_length, xs))
    for l in lengths:
        if l is not None:
            if length is None:
                length = l
            elif length != l:
                raise ValueError(
                    "There are two different leading-dimension lengths: "
                    f"{length} and {l}")
    if length is None:
        raise ValueError(
            "Can't determine length. Please set the `length` argument.")
    xs_ta = tf.nest.map_structure(
        lambda t: (
            tf.TensorArray(t.dtype, size=0, dynamic_size=True).unstack(t)  # pylint: disable=g-long-lambda
            if t is not None else None),
        xs)

    def body(i, carry, ys_ta):
        if reverse:
            i_ = length - 1 - i
        else:
            i_ = i
        xs = tf.nest.map_structure(
            lambda x_ta: x_ta.read(i_) if x_ta is not None else None, xs_ta)
        carry, ys = _np_to_tf(f(*_tf_to_np((carry, xs))))
        ys_ta = tf.nest.map_structure(
            lambda y_ta, y: (y_ta.write(i_, y)
                             if y is not None else y_ta), ys_ta, ys)
        i = i + 1
        return i, carry, ys_ta

    xs_spec = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype)
        if t is not None else None, xs)
    _, ys_spec = eval_on_shapes(f)(init, xs_spec)
    # ys_ta can't contain None because tf.while_loop doesn't allow None in
    # loop_vars.
    ys_ta = tf.nest.map_structure(
        lambda y: tf.TensorArray(
            y.dtype if y is not None else tf.float32,
            size=0,  # pylint: disable=g-long-lambda
            dynamic_size=True),
        ys_spec)
    _, carry, ys_ta = tf.while_loop(lambda i, *_: i < length, body,
                                    (0, init, ys_ta))

    def _stack(a, spec):
        if spec is None:
            return None
        a = a.stack()
        a.set_shape((length, ) + a.shape[1:])
        return a

    ys = tf.nest.map_structure(_stack, ys_ta, ys_spec)
    return _tf_to_np((carry, ys))
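A minimal usage sketch of the `scan` above, computing a running sum (assuming the same module-level `tf_np` alias used throughout these examples):

import numpy as np

def cumsum_step(carry, x):
    new_carry = carry + x        # type ``c -> a -> (c, b)`` with b == c
    return new_carry, new_carry  # new carry, per-step output

final, partials = scan(cumsum_step, tf_np.asarray(0.0),
                       tf_np.asarray(np.arange(1.0, 6.0)))
# final == 15.0; partials == [1., 3., 6., 10., 15.]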
Example #12
def f(x):
    if isinstance(x, (tf.Tensor, tf.IndexedSlices)):
        return tf_np.asarray(x)
    else:
        return x
Example #13
def fun(x, indexer_with_dummies):
    idx = type(indexer)(subvals(indexer_with_dummies, substitutes))
    return jnp.asarray(x)[idx]
Example #14
def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), indexer]
    fun = lambda x, idx: jnp.asarray(x)[idx]
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)
Example #15
def testScanGrad(self, jit_grad, jit_scan, jit_f):
    rng = np.random.RandomState(0)

    d = rng.randn(2)

    def f(c, a):
        assert a.shape == (3,)
        assert c.shape == (4,)
        b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
             tf_np.sum(tf_np.sin(d)))
        c = tf_np.sin(c * b)
        assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
        return c, b

    if jit_f:
        f = extensions.jit(f)

    if jit_scan == "no_xla":
        scan = extensions.jit(extensions.scan, static_argnums=(0,))
    elif jit_scan == "xla_forced_compile":
        # TODO(b/187107596): Remove `skipTest`
        self.skipTest(
            "Taking gradients of `jit(scan, experimental_compile=True)` triggers "
            "'Support for TensorList crossing the XLA/TF boundary is not "
            "implemented' error")
        # `xla_forced_compile=True` doesn't support gradients, so we use
        # `experimental_compile=True`.
        scan = extensions.jit(extensions.scan,
                              static_argnums=(0,),
                              experimental_compile=True)
    else:
        scan = extensions.scan

    xs = tf_np.asarray(rng.randn(5, 3))
    c = tf_np.asarray(rng.randn(4))

    def losses(scan, c, xs):
        c, ys = scan(f, c, xs)
        return tf_np.concatenate(
            tf.nest.flatten(
                tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]),
                                      (c, ys))))

    def loss(scan, c, xs):
        return tf_np.sum(losses(scan, c, xs))

    def grad_origin(c, xs):
        return extensions.grad(functools.partial(loss, scan))(c, xs)

    if jit_grad == "no_xla":
        grad_jit = extensions.jit(grad_origin)
    elif jit_grad == "xla_forced_compile":
        grad_jit = extensions.jit(grad_origin, xla_forced_compile=True)
    else:
        grad_jit = grad_origin

    ans = grad_jit(c, xs)
    expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs)
    self.assertDTypesEqual(expected, ans)
    self.assertAllClose(expected, ans)

    theoretical, numerical = tf.test.compute_gradient(
        to_tf_fn(functools.partial(losses, scan)), (c, xs))
    self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)