Example #1
    def test_hk_jit(self, module_fn: ModuleFn, shape, dtype, init):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def g(x, jit=False):
            mod = module_fn()
            if jit:
                mod = stateful.jit(mod)
            return mod(x)

        f = hk.transform_with_state(g)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-4)

        # NOTE: We shard init/apply tests since some modules are expensive to jit
        # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
        if init:
            jax.tree_map(assert_allclose,
                         jax.jit(f.init)(rng, x), f.init(rng, x, jit=True))

        else:
            params, state = f.init(rng, x)
            jax.tree_map(assert_allclose,
                         jax.jit(f.apply)(params, state, rng, x),
                         f.apply(params, state, rng, x, jit=True))
Example #2
  def test_profiler_name_scopes(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    if not hasattr(xla.xb, 'parameter'):
      self.skipTest('Need Jaxlib version > 0.1.45')

    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
      x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
      x = jax.random.uniform(rng, shape, dtype)

    def g(x, name_scopes=False):
      hk.experimental.profiler_name_scopes(enabled=name_scopes)
      mod = module_fn()
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      f.apply(params, state, rng, x),
                      f.apply(params, state, rng, x, name_scopes=True))

    # TODO(lenamartens): flip to True when default changes
    hk.experimental.profiler_name_scopes(enabled=False)
Example #3
    def test_vmap(self, module_fn: ModuleFn, shape, dtype):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        # Expand our input since we will map over it.
        x = jnp.broadcast_to(x, (2, ) + x.shape)

        f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
        f_mapped = hk.transform_with_state(
            lambda x: hk.vmap(lambda x: module_fn()(x), split_rng=False)(x))  # pylint: disable=unnecessary-lambda

        params, state = f_mapped.init(rng, x)

        # JAX vmap with explicitly unmapped params/state/rng. This should be
        # equivalent to `f_mapped.apply(..)` (since by default hk.vmap does not map
        # params/state/rng).
        v_apply = jax.vmap(f.apply,
                           in_axes=(None, None, None, 0),
                           out_axes=(0, None))

        module_type = descriptors.module_type(module_fn)
        atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=atol)
        jax.tree_map(assert_allclose, f_mapped.apply(params, state, rng, x),
                     v_apply(params, state, rng, x))
Example #4
    def test_optimize_rng_use_under_jit(self, module_fn: ModuleFn, shape,
                                        dtype):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

        module_type = descriptors.module_type(module_fn)
        atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=atol)

        params, state = jax.jit(f.init)(rng, x)
        jax.tree_map(assert_allclose, (params, state), f.init(rng, x))

        if module_type in (hk.nets.VectorQuantizer,
                           hk.nets.VectorQuantizerEMA):
            # For stochastic modules just test apply runs.
            jax.device_get(jax.jit(f.apply)(params, state, rng, x))

        else:
            jax.tree_map(assert_allclose,
                         jax.jit(f.apply)(params, state, rng, x),
                         f.apply(params, state, rng, x))
Example #5
    def test_hk_remat(self, module_fn: ModuleFn, shape, dtype):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def g(x, remat=False):
            mod = module_fn()
            if remat:
                mod = hk.remat(mod)
            out = mod(x)
            if isinstance(out, dict):
                out = out['loss']
            return jnp.mean(out)

        f = hk.transform_with_state(g)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-5)

        grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
        grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                                 has_aux=True)

        params, state = f.init(rng, x)
        jax.tree_map(assert_allclose, grad_jax_remat(params, state, rng, x),
                     grad_hk_remat(params, state, rng, x))
Example #6
    def test_vmap(
        self,
        module_fn: ModuleFn,
        shape: Shape,
        dtype: DType,
    ):
        batch_size, shape = shape[0], shape[1:]
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            sample = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            sample = jax.random.uniform(rng, shape, dtype)
        batch = jnp.broadcast_to(sample, (batch_size, ) + sample.shape)

        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(g)

        # Ensure application under vmap is the same.
        params, state = f.init(rng, sample)
        v_apply = jax.vmap(f.apply, in_axes=(None, None, None, 0))
        jax.tree_multimap(
            lambda a, b: np.testing.assert_allclose(a, b, atol=DEFAULT_ATOL),
            f.apply(params, state, rng, batch),
            v_apply(params, state, rng, batch))
Example #7
    def test_jit(
        self,
        module_fn: ModuleFn,
        shape: Shape,
        dtype: DType,
    ):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(g)

        atol = CUSTOM_ATOL.get(module_type(module_fn), DEFAULT_ATOL)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=atol)

        # Ensure initialization under jit is the same.
        jax.tree_multimap(assert_allclose, f.init(rng, x),
                          jax.jit(f.init)(rng, x))

        # Ensure application under jit is the same.
        params, state = f.init(rng, x)
        jax.tree_multimap(assert_allclose, f.apply(params, state, rng, x),
                          jax.jit(f.apply)(params, state, rng, x))
Example #8
def _get_matrix_parameters(params: Array) -> Array:
    """Get an NxN parameter matrix from per-particle parameters."""
    if isinstance(params, jnp.ndarray):
        if len(params.shape) == 1:
            # NOTE(schsam): get_parameter_matrix only supports additive parameters.
            return 0.5 * (params[:, jnp.newaxis] + params[jnp.newaxis, :])
        elif len(params.shape) == 0 or len(params.shape) == 2:
            return params
        else:
            raise NotImplementedError
    elif (isinstance(params, int) or isinstance(params, float)
          or jnp.issubdtype(params, jnp.integer)
          or jnp.issubdtype(params, jnp.floating)):
        return params
    else:
        raise NotImplementedError
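
A minimal usage sketch for the function above (the array values are illustrative assumptions, not part of the original source): a length-N per-particle vector yields the N x N additive combination.

import jax.numpy as jnp

# Pairwise entries are the additive mean: out[i, j] == 0.5 * (p[i] + p[j]).
sigma = jnp.array([1.0, 2.0, 3.0])
sigma_matrix = _get_matrix_parameters(sigma)
assert sigma_matrix.shape == (3, 3)
assert sigma_matrix[0, 1] == 0.5 * (sigma[0] + sigma[1])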
Example #9
def _get_bond_type_parameters(params, bond_type):
  """Get parameters for interactions for bonds indexed by a bond-type."""
  assert isinstance(bond_type, np.ndarray)
  assert len(bond_type.shape) == 1

  if isinstance(params, np.ndarray):
    if len(params.shape) == 1:
      return params[bond_type]
    elif len(params.shape) == 0:
      return params
    else:
      raise ValueError(
          'Params must be a scalar or a 1d array if using a bond-type lookup.')
  elif (isinstance(params, int) or isinstance(params, float) or
        np.issubdtype(params, np.integer) or np.issubdtype(params, np.floating)):
    return params
  raise NotImplementedError
Example #10
def _to_complex(x):
    if np.issubdtype(x.dtype, np.complexfloating):
        return x
    dtype = dtypes.complex64

    if x.dtype == dtypes.float64:
        dtype = dtypes.complex128
    return _ops.cast(x, dtype)
Example #11
 def preprocess_variate(self, rng, X):
     X = jnp.asarray(X)
     assert X.ndim <= 1, f"unexpected X.shape: {X.shape}"
     assert jnp.issubdtype(
         X.dtype, jnp.integer), f"expected an integer dtype, got {X.dtype}"
     low, high = float(self.space_orig.low), float(self.space_orig.high)
     return jax.nn.one_hot(
         jnp.floor((X - low) * self.num_bins / (high - low)), self.num_bins)
Example #12
def _vjp(pars, forward_fn, v, vec, conjugate):

    # output dtype
    out_dtype = forward_scalar(pars, forward_fn, v[0, :]).dtype

    # convert the sensitivity to right dtype
    vec = jnp.asarray(vec, dtype=out_dtype)

    if tree_leaf_iscomplex(pars):
        if jnp.issubdtype(out_dtype, jnp.complexfloating):  # C -> C
            return _vjp_CC(pars, forward_fn, v, vec, conjugate)
        elif jnp.issubdtype(out_dtype, jnp.floating):  # C -> R
            raise RuntimeError("C->R function detected, but not supported.")
    else:
        if jnp.issubdtype(out_dtype, jnp.complexfloating):  # R -> C
            return _vjp_RC(pars, forward_fn, v, vec, conjugate)
        elif jnp.issubdtype(out_dtype, jnp.floating):  # R -> R
            return _vjp_RR(pars, forward_fn, v, vec, conjugate)
Example #13
    def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
        if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64):
            raise SkipTest("can't test float64 agreement")

        bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
        numpy_bits = np.array(1., dtype).view(bits_dtype)
        xla_bits = api.jit(lambda: lax.bitcast_convert_type(
            np.array(1., dtype), bits_dtype))()
        self.assertEqual(numpy_bits, xla_bits)
Example #14
def _get_bond_type_parameters(params: Array, bond_type: Array) -> Array:
  """Get parameters for interactions for bonds indexed by a bond-type."""
  # TODO(schsam): We should do better error checking here.
  assert isinstance(bond_type, jnp.ndarray)
  assert len(bond_type.shape) == 1

  if isinstance(params, jnp.ndarray):
    if len(params.shape) == 1:
      return params[bond_type]
    elif len(params.shape) == 0:
      return params
    else:
      raise ValueError(
          'Params must be a scalar or a 1d array if using a bond-type lookup.')
  elif (isinstance(params, int) or isinstance(params, float) or
        jnp.issubdtype(params, jnp.integer) or jnp.issubdtype(params, jnp.floating)):
    return params
  raise NotImplementedError
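
A minimal usage sketch (values are illustrative, not from the original source): one constant per bond type is gathered out to one value per bond.

import jax.numpy as jnp

k_by_type = jnp.array([10.0, 20.0])  # one spring constant per bond type
bond_type = jnp.array([0, 1, 1, 0])  # type index of each bond
k_per_bond = _get_bond_type_parameters(k_by_type, bond_type)
assert k_per_bond.shape == (4,)  # [10., 20., 20., 10.]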
Example #15
def _get_neighborhood_matrix_params(idx: Array, params: Array) -> Array:
  if isinstance(params, jnp.ndarray):
    if len(params.shape) == 1:
      return 0.5 * (jnp.reshape(params, params.shape + (1,)) + params[idx])
    elif len(params.shape) == 2:
      def query(id_a, id_b):
        return params[id_a, id_b]
      query = vmap(vmap(query, (None, 0)))
      return query(jnp.arange(idx.shape[0], dtype=jnp.int32), idx)
    elif len(params.shape) == 0:
      return params
    else:
      raise NotImplementedError()
  elif (isinstance(params, int) or isinstance(params, float) or
        jnp.issubdtype(params, jnp.integer) or jnp.issubdtype(params, jnp.floating)):
    return params
  else:
    raise NotImplementedError 
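
A minimal usage sketch (the neighbor list is an illustrative assumption): with 1d per-particle parameters, entry [i, j] combines particle i with its j-th neighbor.

import jax.numpy as jnp

params = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([[1, 2], [0, 2], [0, 1]])  # neighbor indices per particle
out = _get_neighborhood_matrix_params(idx, params)
assert out.shape == idx.shape
assert out[0, 0] == 0.5 * (params[0] + params[1])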
Example #16
    def __init__(self,
                 total_count: Numeric,
                 logits: Optional[Array] = None,
                 probs: Optional[Array] = None,
                 dtype: jnp.dtype = jnp.int_):
        """Initializes a Multinomial distribution.

    Args:
      total_count: The number of trials per sample.
      logits: Logit transform of the probability of each category. Only one
        of `logits` or `probs` can be specified.
      probs: Probability of each category. Only one of `logits` or `probs` can
        be specified.
      dtype:  The type of event samples.
    """
        super().__init__()
        chex.assert_exactly_one_is_none(probs, logits)
        chex.if_args_not_none(chex.assert_axis_dimension_gt,
                              probs,
                              axis=-1,
                              val=1)
        chex.if_args_not_none(chex.assert_axis_dimension_gt,
                              logits,
                              axis=-1,
                              val=1)
        if not (jnp.issubdtype(dtype, jnp.integer)
                or jnp.issubdtype(dtype, jnp.floating)):
            raise ValueError(
                f'The dtype of `{self.name}` must be integer or floating-point, '
                f'instead got `{dtype}`.')

        self._total_count = jnp.asarray(total_count, dtype=dtype)
        self._probs = None if probs is None else math.normalize(probs=probs)
        self._logits = None if logits is None else math.normalize(
            logits=logits)
        self._dtype = dtype

        if self._probs is not None:
            probs_batch_shape = self._probs.shape[:-1]
        else:
            assert self._logits is not None
            probs_batch_shape = self._logits.shape[:-1]
        self._batch_shape = lax.broadcast_shapes(probs_batch_shape,
                                                 self._total_count.shape)
Example #17
 def check(x, y):
     if x.dtype.names is not None:
         for deriv in xderiv, yderiv:
             for dim in deriv:
                 if dim not in x.dtype.names:
                     raise ValueError(f'derivative along missing field {dim!r}')
                 if not jnp.issubdtype(x.dtype.fields[dim][0], jnp.number):
                     raise TypeError(f'derivative along non-numeric field {dim!r}')
     elif not xderiv.implicit or not yderiv.implicit:
         raise ValueError('explicit derivatives with non-structured array')
Example #18
 def normalize_leaf(data: jnp.ndarray, mean: jnp.ndarray,
                    std: jnp.ndarray) -> jnp.ndarray:
     # Only normalize inexact types.
     if not jnp.issubdtype(data.dtype, jnp.inexact):
         return data
     data = (data - mean) / std
     if max_abs_value is not None:
         # TODO(b/124318564): remove pylint directive
         data = jnp.clip(data, -max_abs_value, +max_abs_value)  # pylint: disable=invalid-unary-operand-type
     return data
Example #19
def create_interpolator(points, values):
    if not hasattr(values, "ndim"):
        # allow reasonable duck-typed values
        values = jnp.asarray(values)

    if len(points) > values.ndim:
        raise ValueError("There are %d point arrays, but values has %d "
                         "dimensions" % (len(points), values.ndim))

    if hasattr(values, "dtype") and hasattr(values, "astype"):
        if not jnp.issubdtype(values.dtype, jnp.inexact):
            values = values.astype(float)

    for i, p in enumerate(points):
        if not jnp.all(jnp.diff(p) > 0.0):
            raise ValueError(
                "The points in dimension %d must be strictly ascending" % i)
        if not jnp.asarray(p).ndim == 1:
            raise ValueError(
                "The points in dimension %d must be 1-dimensional" % i)
        if not values.shape[i] == len(p):
            raise ValueError("There are %d points and %d values in "
                             "dimension %d" % (len(p), values.shape[i], i))
    grid = tuple([jnp.asarray(p) for p in points])
    ndim = len(grid)

    def interpolator(xi, method="linear"):
        if method not in ["linear", "nearest"]:
            raise ValueError("Method '%s' is not defined" % method)

        xi = _ndim_coords_from_arrays(xi, ndim)
        if xi.shape[-1] != len(grid):
            raise ValueError("The requested sample points xi have dimension "
                             "%d, but this RegularGridInterpolator has "
                             "dimension %d" % (xi.shape[1], ndim))

        xi_shape = xi.shape
        xi = xi.reshape(-1, xi_shape[-1])

        for i, p in enumerate(xi.T):
            if not jnp.logical_and(jnp.all(grid[i][0] <= p),
                                   jnp.all(p <= grid[i][-1])):
                raise ValueError(
                    "One of the requested xi is out of bounds in dimension %d"
                    % i)

        indices, norm_distances = _find_indices(xi.T, grid)
        if method == "linear":
            result = _evaluate_linear(values, indices, norm_distances)
        elif method == "nearest":
            result = _evaluate_nearest(values, indices, norm_distances)

        return result.reshape(xi_shape[:-1] + values.shape[ndim:])

    return interpolator
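
A minimal usage sketch (assuming the private helpers referenced above, such as _find_indices and _evaluate_linear, are available; the grid values are illustrative):

import jax.numpy as jnp

points = (jnp.array([0.0, 1.0, 2.0]),)
values = jnp.array([0.0, 10.0, 20.0])
interp = create_interpolator(points, values)
# Linear interpolation halfway between the grid points 0.0 and 1.0.
assert jnp.allclose(interp(jnp.array([[0.5]])), jnp.array([5.0]))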
Example #20
def as_float_array(x: Numeric) -> Array:
    """Converts input to an array with floating-point dtype.

  If the input is already an array with floating-point dtype, it is returned
  unchanged.

  Args:
    x: input to convert.

  Returns:
    An array with floating-point dtype.
  """
    if not isinstance(x, Array):
        x = jnp.asarray(x)
    if jnp.issubdtype(x.dtype, jnp.floating):
        return x
    elif jnp.issubdtype(x.dtype, jnp.integer):
        return x.astype(jnp.float_)
    else:
        raise ValueError(
            f"Expected either floating or integer dtype, got {x.dtype}.")
Example #21
 def _param_func_generator(self, data, dist_name, params, batch_shape, func,
                           generate_sample_function=False):
   for param_name, param in params.items():
     if (not tf.is_tensor(param)
         or not np.issubdtype(param.dtype, np.floating)):
       continue
     def _func(param_name, param):
       dist = data.draw(self._make_distribution(
           dist_name, params, batch_shape,
           override_params={param_name: param}))
       return func(dist)
     yield param_name, param, _func
Example #22
 def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:
     real_dtype = jnp.finfo(dtype).dtype
     m = jax.lax.convert_element_type(self.mean, dtype)
     s = jax.lax.convert_element_type(self.stddev, real_dtype)
     is_complex = jnp.issubdtype(dtype, jnp.complexfloating)
     if is_complex:
         shape = [2, *shape]
     unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2.,
                                            shape, real_dtype)
     if is_complex:
         unscaled = unscaled[0] + 1j * unscaled[1]
     return s * unscaled + m
Example #23
def _convert_element_type(operand, new_dtype):
  head, tail = operand
  head = lax.convert_element_type_p.bind(head, new_dtype=new_dtype)
  if tail is not None:
    tail = lax.convert_element_type_p.bind(tail, new_dtype=new_dtype)
  if jnp.issubdtype(new_dtype, jnp.floating):
    if tail is None:
      tail = jnp.zeros_like(head)
  elif tail is not None:
    head = head + tail
    tail = None
  return (head, tail)
Example #24
    def testQr(self, shape, dtype, full_matrices, rng_factory):
        rng = rng_factory()
        _skip_if_unsupported_type(dtype)
        if (np.issubdtype(dtype, onp.complexfloating)
                and (jtu.device_under_test() == "tpu" or jax.lib.version <=
                     (0, 1, 27))):
            raise unittest.SkipTest("No complex QR implementation")
        m, n = shape[-2:]

        if full_matrices:
            mode, k = "complete", m
        else:
            mode, k = "reduced", min(m, n)

        a = rng(shape, dtype)
        lq, lr = np.linalg.qr(a, mode=mode)

        # onp.linalg.qr doesn't support batch dimensions. But it seems like an
        # inevitable extension so we support it in our version.
        nq = onp.zeros(shape[:-2] + (m, k), dtype)
        nr = onp.zeros(shape[:-2] + (k, n), dtype)
        for index in onp.ndindex(*shape[:-2]):
            nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

        max_rank = max(m, n)

        # Norm, adjusted for dimension and type.
        def norm(x):
            n = onp.linalg.norm(x, axis=(-2, -1))
            return n / (max_rank * np.finfo(dtype).eps)

        def compare_orthogonal(q1, q2):
            # Q is unique up to sign, so normalize the sign first.
            sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
            phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
            q1 *= phases
            self.assertTrue(onp.all(norm(q1 - q2) < 30))

        # Check a ~= qr
        self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

        # Compare the first 'k' vectors of Q; the remainder form an arbitrary
        # orthonormal basis for the null space.
        compare_orthogonal(nq[..., :k], lq[..., :k])

        # Check that q is close to unitary.
        self.assertTrue(
            onp.all(norm(onp.eye(k) - onp.matmul(onp.conj(T(lq)), lq)) < 5))

        if not full_matrices and m >= n:
            jtu.check_jvp(np.linalg.qr,
                          partial(jvp, np.linalg.qr), (a, ),
                          atol=3e-3)
Example #25
    def test_fast_eval_shape_inside_transform(self, module_fn: ModuleFn, shape,
                                              dtype):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        m = module_fn()
        m_slow = hk.eval_shape(m, x)
        m_fast = hk.experimental.fast_eval_shape(m, x)
        self.assertEqual(m_slow, m_fast)
Example #26
def _make_rotate_left(dtype):
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError("_rotate_left only accepts integer dtypes.")
  nbits = np.array(jnp.iinfo(dtype).bits, dtype)

  def _rotate_left(x, d):
    if lax.dtype(d) != dtype:
      d = lax.convert_element_type(d, dtype)
    if lax.dtype(x) != dtype:
      x = lax.convert_element_type(x, dtype)
    return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)
  return _rotate_left
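
A minimal usage sketch (the values are illustrative assumptions): rotating left by one doubles a small value, and the high bit wraps around to bit zero.

import jax.numpy as jnp

rotate_left = _make_rotate_left(jnp.uint32)
assert int(rotate_left(jnp.uint32(1), jnp.uint32(1))) == 2
assert int(rotate_left(jnp.uint32(0x80000000), jnp.uint32(1))) == 1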
Example #27
 def testEighBatching(self, shape, dtype, rng_factory):
   rng = rng_factory()
   _skip_if_unsupported_type(dtype)
   if (jtu.device_under_test() == "tpu" and
       np.issubdtype(dtype, onp.complexfloating)):
     raise unittest.SkipTest("No complex eigh on TPU")
   shape = (10,) + shape
   args = rng(shape, dtype)
   args = (args + onp.conj(T(args))) / 2
   ws, vs = vmap(jsp.linalg.eigh)(args)
   self.assertTrue(onp.all(onp.linalg.norm(
       onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
Example #28
def _convert_to_tensor(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
    """Emulates tf.convert_to_tensor."""
    assert not tf.is_tensor(value), value
    if isinstance(value, np.ndarray):
        if dtype is not None:
            dtype = utils.numpy_dtype(dtype)
            # if np.result_type(value, dtype) != dtype:
            #   raise ValueError('Expected dtype {} but got {} with dtype {}.'.format(
            #       dtype, value, value.dtype))
            return value.astype(dtype)
        return value
    if isinstance(value, TensorShape):
        value = [int(d) for d in value.as_list()]
    if dtype is None and dtype_hint is not None:
        dtype_hint = utils.numpy_dtype(dtype_hint)
        value = np.array(value)
        if np.size(value):
            # Match TF behavior, which won't downcast e.g. float to int.
            if np.issubdtype(value.dtype, np.complexfloating):
                if not np.issubdtype(dtype_hint, np.complexfloating):
                    return value
            if np.issubdtype(value.dtype, np.floating):
                if not np.issubdtype(dtype_hint, np.floating):
                    return value
            if np.issubdtype(value.dtype, np.integer):
                if not np.issubdtype(dtype_hint, np.integer):
                    return value
        return value.astype(dtype_hint)
    return np.array(value, dtype=utils.numpy_dtype(dtype or dtype_hint))
Example #29
    def __call__(
        self,
        ids: jnp.ndarray,
        lookup_style: Optional[Union[str, hk.EmbedLookupStyle]] = None,
        precision: Optional[jax.lax.Precision] = None,
    ) -> jnp.ndarray:
        r"""Lookup embeddings.

    Looks up an embedding vector for each value in ``ids``. All ids must be
    within ``[0, vocab_size)`` to prevent ``NaN``\ s from propagating.

    Args:
      ids: integer array.
      lookup_style: Overrides the ``lookup_style`` given in the constructor.
      precision: Overrides the ``precision`` given in the constructor.

    Returns:
      Tensor of ``ids.shape + [embedding_dim]``.

    Raises:
      AttributeError: If ``lookup_style`` is not valid.
      ValueError: If ``ids`` is not an integer array.
    """
        # TODO(tomhennigan) Consider removing asarray here.
        ids = jnp.asarray(ids)
        if not jnp.issubdtype(ids.dtype, jnp.integer):
            raise ValueError(
                "hk.Embed's __call__ method must take an array of "
                "integer dtype but was called with an array of "
                f"{ids.dtype}")

        lookup_style = lookup_style or self.lookup_style
        if isinstance(lookup_style, str):
            lookup_style = getattr(hk.EmbedLookupStyle, lookup_style.upper())

        if lookup_style == hk.EmbedLookupStyle.ARRAY_INDEX:
            # If you don't wrap ids in a singleton tuple then JAX will try to unpack
            # it along the row dimension and treat each row as a separate index into
            # one of the dimensions of the array. The error only surfaces when
            # indexing with DeviceArray, while indexing with numpy.ndarray works fine.
            # See https://github.com/google/jax/issues/620 for more details.
            # Cast to a jnp array in case `ids` is a tracer (e.g. in a dynamic_unroll).
            return jnp.asarray(self.embeddings)[(ids, )]

        elif lookup_style == hk.EmbedLookupStyle.ONE_HOT:
            one_hot_ids = jax.nn.one_hot(ids, self.vocab_size)
            precision = self.precision if precision is None else precision
            return jnp.dot(one_hot_ids, self.embeddings, precision=precision)

        else:
            raise NotImplementedError(
                f"{lookup_style} is not supported by hk.Embed.")
Example #30
 def testEigvalsh(self, shape, dtype, rng_factory):
   rng = rng_factory()
   _skip_if_unsupported_type(dtype)
   if jtu.device_under_test() == "tpu":
     if np.issubdtype(dtype, np.complexfloating):
       raise unittest.SkipTest("No complex eigh on TPU")
   n = shape[-1]
   def args_maker():
     a = rng((n, n), dtype)
     a = (a + onp.conj(a.T)) / 2
     return [a]
   self._CheckAgainstNumpy(onp.linalg.eigvalsh, np.linalg.eigvalsh, args_maker,
                           check_dtypes=True, tol=1e-3)