Example #1
  def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
                   filter_shape, phi, use_pooling, proj_into_2d):
    if xla_bridge.get_backend().platform == 'tpu' and same_inputs:
      raise jtu.SkipTest(
          'Skip TPU test for `same_inputs`. Need to handle '
          'random keys carefully for dropout + empirical kernel.')

    pool_type = 'AVG'
    use_dropout = True
    is_conv = 'conv' in model
    is_res = False
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None
    parameterization = 'ntk'
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_shape !=
                                   (1, 1)))):
        raise jtu.SkipTest('Different paths in a residual model need to return'
                           ' outputs of the same shape.')
    elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise jtu.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                         is_ntk, proj_into_2d)
Example #2
  def test_exact(self, model, width, strides, padding, phi, same_inputs,
                 filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_size !=
                                   (1, 1)))):
        raise jtu.SkipTest('Different paths in a residual model need to return'
                           ' outputs of the same shape.')
    elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise jtu.SkipTest('FC models do not have these parameters.')

    if (proj_into_2d.startswith('ATTN') and strides == (2, 1) and
        padding == 'VALID' and xla_bridge.get_backend().platform == 'tpu'):
      # TODO: speed up the vmap alternative impl or fix the current one
      raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                         ' the spatial dimensions is singleton.')

    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None

    self._check_agreement_with_empirical(W_std, b_std, filter_size, is_conv,
                                         is_ntk, is_res, layer_norm, padding,
                                         phi, proj_into_2d, same_inputs,
                                         strides, use_pooling, width)
Example #3
def from_dlpack(dlpack, backend=None):
    """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.

  The returned `DeviceArray` shares memory with `dlpack`.

  Args:
    dlpack: a DLPack tensor, on either CPU or GPU.
    backend: deprecated, do not use.
  """
    if jax.lib._xla_extension_version >= 25:
        cpu_backend = xla_bridge.get_backend("cpu")
        try:
            gpu_backend = xla_bridge.get_backend("gpu")
        except RuntimeError:
            gpu_backend = None
        buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
            dlpack, cpu_backend, gpu_backend)
    else:
        # TODO(phawkins): drop the backend argument after deleting this case.
        backend = backend or xla_bridge.get_backend()
        client = getattr(backend, "client", backend)
        buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)

    xla_shape = buf.xla_shape()
    assert not xla_shape.is_tuple()
    aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
    return xla.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
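The conversion can be exercised end to end with the public `jax.dlpack` helpers. The snippet below is a minimal usage sketch, not part of the example above; note that a DLPack capsule may only be consumed once.

# Hedged usage sketch: round-trip a JAX array through a DLPack capsule.
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
capsule = jax.dlpack.to_dlpack(x)    # export the array as a DLPack capsule
y = jax.dlpack.from_dlpack(capsule)  # import it back; shares device memory with x
assert y.shape == (2, 3)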
Example #4
  def test_exact(self, model, width, strides, padding, phi, same_inputs,
                 filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_size !=
                                   (1, 1)))):
        raise jtu.SkipTest('Different paths in a residual model need to return'
                           ' outputs of the same shape.')
    elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise jtu.SkipTest('FC models do not have these parameters.')

    if (proj_into_2d.startswith('ATTN') and strides == (2, 1) and
        padding == 'VALID' and xla_bridge.get_backend().platform == 'tpu'):
      # TODO(jirihron): speed up the vmap alternative impl or fix the current one
      raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                         ' the spatial dimensions is singleton.')

    W_std, b_std = 2.**0.5, 0.5**0.5

    key = random.PRNGKey(1)
    x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE)

    init_fn, apply_fn, kernel_fn = _get_net(W_std, b_std, filter_size,
                                            is_conv, use_pooling, is_res,
                                            padding, phi, strides, width,
                                            is_ntk, proj_into_2d)

    def _get_empirical(n_samples, get):
      kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn(
          init_fn, apply_fn, key, n_samples)
      return kernel_fn_empirical(x1, x2, get)

    if proj_into_2d == 'ATTN_PARAM':
      # no analytic kernel available, just test forward/backward pass
      _get_empirical(1, 'ntk' if is_ntk else 'nngp')
    else:
      if is_ntk:
        exact = kernel_fn(x1, x2, 'ntk')
        empirical = np.reshape(_get_empirical(N_SAMPLES, 'ntk'), exact.shape)
      else:
        exact = kernel_fn(x1, x2, 'nngp')
        empirical = _get_empirical(N_SAMPLES, 'nngp')
      utils.assert_close_matrices(self, empirical, exact, RTOL)
Example #5
    def test_convert_scalars(self):
        # TODO(jblespiau): Remove when the version is out.
        if jaxlib.version < (0, 1, 53):
            return

        jax_jit = jaxlib.jax_jit

        jax_enable_x64 = FLAGS.jax_enable_x64

        if jax_enable_x64:
            int_type = np.int64
            float_type = np.float64
            complex_type = np.complex128
        else:
            int_type = np.int32
            float_type = np.float32
            complex_type = np.complex64

        # int
        res = jax_jit._ScalarToBuffer(1, jax_enable_x64,
                                      xla_bridge.get_backend()).to_py()
        self.assertEqual(res, 1)
        self.assertEqual(res.dtype, int_type)
        # We also compare to the Python Jax API, to make sure we have the exact
        # same behavior. When Jax removes the flag and removes this feature, this
        # test will fail.
        self.assertEqual(jnp.asarray(1).dtype, res.dtype)

        # float
        res = jax_jit._ScalarToBuffer(1.0, jax_enable_x64,
                                      xla_bridge.get_backend()).to_py()
        self.assertEqual(res, 1.0)
        self.assertEqual(res.dtype, float_type)
        self.assertEqual(jnp.asarray(1.0).dtype, res.dtype)

        # bool
        for bool_value in [True, False]:
            res = jax_jit._ScalarToBuffer(bool_value, jax_enable_x64,
                                          xla_bridge.get_backend()).to_py()
            self.assertEqual(res, np.asarray(bool_value))
            self.assertEqual(res.dtype, np.bool_)
            self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)

        # Complex
        res = jax_jit._ScalarToBuffer(1 + 1j, jax_enable_x64,
                                      xla_bridge.get_backend()).to_py()
        self.assertEqual(res, 1 + 1j)
        self.assertEqual(res.dtype, complex_type)
        self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
Example #6
def static_cast(*xs):
    """Function to cast a value to the lowest dtype that can express it."""
    # NOTE(schsam): static_cast is so named because it cannot be jit.
    if xla_bridge.get_backend().platform == 'tpu':
        return (np.array(x, np.float32) for x in xs)
    else:
        return (np.array(x, dtype=onp.min_scalar_type(x)) for x in xs)
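For intuition on what `static_cast` produces off TPU, the sketch below (assuming `onp` is plain NumPy, as in the example) shows how `onp.min_scalar_type` picks the narrowest dtype that can represent each Python scalar.

# Hedged illustration of the dtype selection used by static_cast off TPU.
import numpy as onp

print(onp.min_scalar_type(1))    # uint8
print(onp.min_scalar_type(-1))   # int8
print(onp.min_scalar_type(2.5))  # float16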
Example #7
    def testIsOnCPU(self):
        for dtype in [np.float32, np.float64]:
            with self.subTest(dtype=dtype):

                def x():
                    return random.normal(random.PRNGKey(1), (2, 3), dtype)

                def x_cpu():
                    return device_get(
                        random.normal(random.PRNGKey(1), (2, 3), dtype))

                x_jit = jit(x)
                # x_cpu_jit = jit(x_cpu)
                x_cpu_jit_cpu = jit(x_cpu, backend='cpu')

                self.assertTrue(utils.is_on_cpu(x_cpu()))
                # TODO(mattjj): re-enable this when device_put under jit works
                # self.assertTrue(utils.is_on_cpu(x_cpu_jit()))
                self.assertTrue(utils.is_on_cpu(x_cpu_jit_cpu()))

                if xla_bridge.get_backend().platform == 'cpu':
                    self.assertTrue(utils.is_on_cpu(x()))
                    self.assertTrue(utils.is_on_cpu(x_jit()))
                else:
                    self.assertFalse(utils.is_on_cpu(x()))
                    self.assertFalse(utils.is_on_cpu(x_jit()))
Example #8
    def testIsOnCPU(self):
        for dtype in [np.float32, np.float64]:
            with self.subTest(dtype=dtype):

                def x():
                    return random.normal(random.PRNGKey(1), (2, 3), dtype)

                def x_cpu():
                    return device_get(
                        random.normal(random.PRNGKey(1), (2, 3), dtype))

                x_jit = jit(x)
                x_cpu_jit = jit(x_cpu)
                x_cpu_jit_cpu = jit(x_cpu, backend='cpu')

                self.assertTrue(predict._is_on_cpu(x_cpu()))
                self.assertTrue(predict._is_on_cpu(x_cpu_jit()))
                self.assertTrue(predict._is_on_cpu(x_cpu_jit_cpu()))

                if xla_bridge.get_backend().platform == 'cpu':
                    self.assertTrue(predict._is_on_cpu(x()))
                    self.assertTrue(predict._is_on_cpu(x_jit()))
                else:
                    self.assertFalse(predict._is_on_cpu(x()))
                    self.assertFalse(predict._is_on_cpu(x_jit()))
Example #9
    def test_layernorm(self, model, width, same_inputs, is_ntk, proj_into_2d,
                       layer_norm):
        is_conv = 'conv' in model
        # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
        if is_conv:
            if xla_bridge.get_backend().platform == 'cpu':
                raise jtu.SkipTest(
                    'Not running CNN models on CPU to save time.')
        elif proj_into_2d != PROJECTIONS[0] or layer_norm != LAYER_NORM[0]:
            raise jtu.SkipTest('FC models do not have these parameters.')

        W_std, b_std = 2.**0.5, 0.5**0.5
        filter_size = FILTER_SIZES[0]
        padding = PADDINGS[0]
        strides = STRIDES[0]
        phi = stax.Relu()
        use_pooling, is_res = False, False
        parameterization = 'ntk'
        use_dropout = False

        self._check_agreement_with_empirical(W_std, b_std, filter_size,
                                             is_conv, is_ntk, is_res,
                                             layer_norm, padding, phi,
                                             proj_into_2d, same_inputs,
                                             strides, use_pooling, width,
                                             parameterization, use_dropout)
Example #10
def device_memory_profile(backend: Optional[str] = None) -> bytes:
    """Captures a JAX device memory profile as ``pprof``-format protocol buffer.

  A device memory profile is a snapshot of the state of memory, that describes the JAX
  :class:`jax.DeviceArray` and executable objects present in memory and their
  allocation sites.

  For more information how to use the device memory profiler, see
  :doc:`/device_memory_profiling`.

  The profiling system works by instrumenting JAX on-device allocations,
  capturing a Python stack trace for each allocation. The instrumentation is
  always enabled; :func:`device_memory_profile` provides an API to capture it.

  The output of :func:`device_memory_profile` is a binary protocol buffer that
  can be interpreted and visualized by the `pprof tool
  <https://github.com/google/pprof>`_.

  Args:
    backend: optional; the name of the JAX backend for which the device memory
      profile should be collected.

  Returns:
    A byte string containing a binary `pprof`-format protocol buffer.
  """
    return xla_client.heap_profile(xla_bridge.get_backend(backend))
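The returned bytes are meant to be written to a file and inspected with the pprof tool. A minimal usage sketch, assuming the public `jax.profiler.device_memory_profile` entry point:

# Hedged usage sketch: capture a device memory profile and dump it for pprof.
import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))  # allocate something on the default device

with open("memory.prof", "wb") as f:
    f.write(jax.profiler.device_memory_profile())

# Then, outside Python:  pprof --web memory.prof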
Example #11
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version
    cuda_version = None if version == "<unknown>" else int(version.split()[-1])
    if cuda_version is None or cuda_version < 11000:
      self.assertNotIn(sparse_ops.csr_todense_p,
                       xla.backend_specific_translations["gpu"])
    else:
      self.assertIn(sparse_ops.csr_todense_p,
                    xla.backend_specific_translations["gpu"])
Example #12
  def test_parameterizations(self, model, width, same_inputs, is_ntk,
                             filter_shape, proj_into_2d, parameterization):
    is_conv = 'conv' in model

    W_std, b_std = 2.**0.5, 0.5**0.5
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    layer_norm = None
    pool_type = 'AVG'
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    elif proj_into_2d != PROJECTIONS[0]:
      raise jtu.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                         is_ntk, proj_into_2d)
Example #13
  def testPredictOnCPU(self):
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

    for store_on_device in [False, True]:
      for device_count in [0, 1]:
        for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
          for x in [None, 'x_test']:
            with self.subTest(
                store_on_device=store_on_device,
                device_count=device_count,
                get=get,
                x=x):
              kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                              store_on_device)
              predictor = predict.gradient_descent_mse_ensemble(
                  kernel_fn_batched, x_train, y_train)

              x = x if x is None else x_test
              predict_none = predictor(None, x, get, compute_cov=True)
              predict_inf = predictor(np.inf, x, get, compute_cov=True)
              self.assertAllClose(predict_none, predict_inf)

              if x is not None:
                on_cpu = (not store_on_device or
                          xla_bridge.get_backend().platform == 'cpu')
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
Example #14
def train_loop(key,
               init_params,
               loss_fn,
               parallel=True,
               summarize_fn=default_summarize,
               lr=1e-4,
               num_steps=int(1e5),
               summarize_every=100,
               checkpoint_every=5000,
               clobber_checkpoint=False,
               logdir="/tmp/lda_inference"):

    if not parallel:
        train_fn = local_train_loop
    elif parallel and can_train_parallel():
        train_fn = parallel_train_loop
    else:
        print(
            "Platform is %s and num devices is %d, defaulting to local training."
            % (xla_bridge.get_backend().platform, len(xla_bridge.devices())))
        train_fn = local_train_loop

    train_fn(key,
             init_params,
             loss_fn,
             summarize_fn=summarize_fn,
             lr=lr,
             num_steps=num_steps,
             summarize_every=summarize_every,
             checkpoint_every=checkpoint_every,
             clobber_checkpoint=clobber_checkpoint,
             logdir=logdir)
Example #15
def _gamma_grad(sample, a):
    samples = np.reshape(sample, -1)
    alphas = np.reshape(a, -1)
    if xla_bridge.get_backend().platform == 'cpu':
        grads = lax.map(lambda args: _gamma_grad_one(*args), (samples, alphas))
    else:
        grads = vmap(_gamma_grad_one)(samples, alphas)
    return grads.reshape(onp.shape(a))
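The CPU-vs-accelerator dispatch seen here recurs in several examples below (_binomial, _poisson, _gamma_impl). Below is a self-contained sketch of the pattern; `map_or_vmap` and `elementwise_fn` are hypothetical names introduced only for illustration.

# Hedged sketch: sequential lax.map on CPU (cheap to compile), vectorized vmap
# on accelerators.
from jax import lax, vmap
from jax.lib import xla_bridge
import jax.numpy as jnp

def map_or_vmap(elementwise_fn, xs):
  if xla_bridge.get_backend().platform == 'cpu':
    return lax.map(elementwise_fn, xs)
  return vmap(elementwise_fn)(xs)

ys = map_or_vmap(lambda x: jnp.sin(x) ** 2, jnp.linspace(0.0, 1.0, 8))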
Example #16
    def testNTKMeanPrediction(self, train_shape, test_shape, network,
                              out_logits):

        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = np.cos(random.normal(split, test_shape))
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)
        mean_pred, var = predict.gp_inference(ker_fun,
                                              data_train,
                                              data_labels,
                                              data_test,
                                              diag_reg=0.,
                                              mode='NTK',
                                              compute_var=True)

        if xla_bridge.get_backend().platform == 'tpu':
            eigh = np.onp.linalg.eigh
        else:
            eigh = np.linalg.eigh

        self.assertEqual(var.shape[0], data_test.shape[0])
        min_eigh = np.min(eigh(var)[0])
        self.assertGreater(min_eigh + 1e-10, 0.)

        def mc_sampling(count=10):
            empirical_mean = 0.
            key = random.PRNGKey(100)
            for _ in range(count):
                key, split = random.split(key)
                params, f, theta = _empirical_kernel(split, train_shape[1:],
                                                     network, out_logits)
                g_dd = theta(data_train, None)
                g_td = theta(data_test, data_train)
                predictor = predict.gradient_descent_mse(
                    g_dd, data_labels, g_td)

                fx_initial_train = f(params, data_train)
                fx_initial_test = f(params, data_test)

                _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                            fx_initial_test)
                empirical_mean += fx_pred_test
            return empirical_mean / count

        atol = ATOL
        rtol = RTOL
        mean_emp = mc_sampling(100)

        self.assertAllClose(mean_pred, mean_emp, True, rtol, atol)
Example #17
    def testPredictOnCPU(self):
        key1 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key2 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key3 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
        x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

        y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    for x in [None, 'x_test']:
                        with self.subTest(store_on_device=store_on_device,
                                          device_count=device_count,
                                          get=get,
                                          x=x):
                            kernel_fn_batched = batch.batch(
                                kernel_fn, 2, device_count, store_on_device)
                            predictor = predict.gradient_descent_mse_ensemble(
                                kernel_fn_batched, x_train, y_train)

                            x = x if x is None else x_test
                            predict_none = predictor(None,
                                                     x,
                                                     get,
                                                     compute_cov=True)
                            predict_inf = predictor(np.inf,
                                                    x,
                                                    get,
                                                    compute_cov=True)
                            self.assertAllClose(predict_none, predict_inf)

                            if x is not None:
                                on_cpu = (not store_on_device
                                          or xla_bridge.get_backend().platform
                                          == 'cpu')
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_inf))
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_none))
Example #18
  def _res_tf_to_jax(res_tf):
    if isinstance(res_tf, tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
      res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
      res_jax_platform = res_tf_platform.lower()
      if res_jax_platform in _DLPACK_PLATFORMS:
        res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
        return jax.dlpack.from_dlpack(
            res_dlpack, backend=xla_bridge.get_backend(res_jax_platform))

    return _device_put_raw(np.asarray(res_tf))
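The zero-copy hand-off this relies on can be tried directly. A minimal sketch, assuming both TensorFlow and JAX are installed:

# Hedged sketch of the TF -> JAX DLPack hand-off used above.
import tensorflow as tf
import jax.dlpack

x_tf = tf.constant([[1.0, 2.0], [3.0, 4.0]])
capsule = tf.experimental.dlpack.to_dlpack(x_tf)  # export the TF tensor
x_jax = jax.dlpack.from_dlpack(capsule)           # import into JAX, sharing memory
print(x_jax.shape, x_jax.dtype)                   # (2, 2) float32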
Example #19
def _binomial(key, p, n, shape):
    shape = shape or lax.broadcast_shapes(np.shape(p), np.shape(n))
    # reshape to map over axis 0
    p = np.reshape(np.broadcast_to(p, shape), -1)
    n = np.reshape(np.broadcast_to(n, shape), -1)
    key = random.split(key, np.size(p))
    if xla_bridge.get_backend().platform == 'cpu':
        ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
    else:
        ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
    return np.reshape(ret, shape)
Example #20
def _eigh(mat):
    """Platform specific eigh."""
    # TODO(schsam): Eventually, we may want to handle non-symmetric kernels for
    # e.g. masking. Additionally, once JAX supports eigh on TPU, we probably want
    # to switch to JAX's eigh.
    if xla_bridge.get_backend().platform == 'tpu':
        eigh = np.onp.linalg.eigh
    else:
        eigh = np.linalg.eigh
        eigh = jit(eigh, backend='cpu') if _is_on_cpu(mat) else jit(eigh)
    return eigh(mat)
Example #21
    def testTorchToJaxFailure(self):
        x = torch.arange(6).reshape((2, 3))
        y = torch.utils.dlpack.to_dlpack(x[:, :2])

        backend = xla_bridge.get_backend()
        client = getattr(backend, "client", backend)

        regex_str = (
            r'Unimplemented: Only DLPack tensors with trivial \(compact\) '
            r'striding are supported')
        with self.assertRaisesRegex(RuntimeError, regex_str):
            xla_client._xla.dlpack_managed_tensor_to_buffer(y, client)
Example #22
def stub_out_pmap(batch, count):
    # If we are using GPU or CPU, stub out pmap with vmap to simulate multi-core.
    if count > 0:

        class xla_bridge_stub(object):
            def device_count(self):
                return count

        platform = xla_bridge.get_backend().platform
        if platform == 'gpu' or platform == 'cpu':
            batch.pmap = _jit_vmap
            batch.xla_bridge = xla_bridge_stub()
Example #23
def _poisson(key, rate, shape, dtype):
    # Ref: https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
    shape = shape or np.shape(rate)
    rate = lax.convert_element_type(rate, canonicalize_dtype(np.float64))
    rate = np.broadcast_to(rate, shape)
    rng_keys = random.split(key, np.size(rate))
    if xla_bridge.get_backend().platform == 'cpu':
        k = lax.map(_poisson_one, (rng_keys, np.reshape(rate, -1)))
    else:
        k = vmap(_poisson_one)((rng_keys, np.reshape(rate, -1)))
    k = lax.convert_element_type(k, dtype)
    return np.reshape(k, shape)
Example #24
def _gamma_impl(key, a):
    a_shape = np.shape(a)
    # split key to match the shape of a
    key_ndim = np.ndim(key) - 1
    key = np.reshape(key, (-1, 2))
    key = vmap(split, in_axes=(0, None))(key, prod(a_shape[key_ndim:]))
    keys = np.reshape(key, (-1, 2))
    alphas = np.reshape(a, -1)
    if xla_bridge.get_backend().platform == 'cpu':
        samples = lax.map(lambda args: _gamma_one(*args), (keys, alphas))
    else:
        samples = vmap(_gamma_one)(keys, alphas)
    return np.reshape(samples, a_shape),
Example #25
    def test_conv_local_general_dilated(self, n, padding, lhs_spec, rhs_spec,
                                        out_spec):
        """Make sure LCN with tiled CNN kernel matches CNN."""
        if xla_bridge.get_backend().platform == 'cpu' and n > 1:
            raise absltest.SkipTest('Skipping large tests on CPU.')

        lhs_spec_default = 'NCHWDX'[:n + 2]
        rhs_spec_default = 'OIHWDX'[:n + 2]

        lhs_default = random.normal(random.PRNGKey(1),
                                    (2, 4, 7, 6, 5, 8)[:n + 2])
        rhs_default = random.normal(random.PRNGKey(2),
                                    (3, 4, 2, 3, 1, 2)[:n + 2])

        window_strides = (1, 2, 3, 4)[:n]
        rhs_dilation = (2, 1, 3, 2)[:n]

        lhs_perm = [lhs_spec_default.index(c) for c in lhs_spec]
        lhs = np.transpose(lhs_default, lhs_perm)

        rhs_perm = [rhs_spec_default.index(c) for c in rhs_spec]
        rhs = np.transpose(rhs_default, rhs_perm)

        kwargs = dict(lhs=lhs,
                      window_strides=window_strides,
                      padding=padding,
                      rhs_dilation=rhs_dilation,
                      dimension_numbers=(lhs_spec, rhs_spec, out_spec))

        out_conv = lax.conv_general_dilated(rhs=rhs, **kwargs)

        rhs_local = np.moveaxis(rhs,
                                (rhs_spec.index('O'), rhs_spec.index('I')),
                                (0, 1))
        rhs_local = rhs_local.reshape((rhs_local.shape[0], -1) + (1, ) * n)

        rhs_shape = (rhs_local.shape[:2] +
                     tuple(out_conv.shape[out_spec.index(c)]
                           for c in rhs_spec_default[2:]))

        rhs_local = np.broadcast_to(rhs_local, rhs_shape)
        rhs_local = np.transpose(rhs_local, rhs_perm)

        filter_shape = [
            rhs.shape[i] for i in range(n + 2) if rhs_spec[i] not in ('O', 'I')
        ]
        out_local = utils.conv_local_general_dilated(rhs=rhs_local,
                                                     filter_shape=filter_shape,
                                                     **kwargs)

        self.assertAllClose(out_conv, out_local, atol=1e-5, rtol=1e-5)
Example #26
def _replicate(x, devices=None):
  x = jax.numpy.array(x)
  if devices is None:
    # match the default device assignments used in pmap:
    # for single-host, that's the XLA default device assignment
    # for multi-host, it's the order of jax.local_devices()
    if jax.host_count() == 1:
      devices = [d for d in xb.get_backend().get_default_device_assignment(
          jax.device_count()) if d.host_id == jax.host_id()]
    else:
      devices = jax.local_devices()
  aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
  buffers = [jax.interpreters.xla.device_put(x, device=d) for d in devices]
  return jax.pxla.ShardedDeviceArray(aval, buffers)
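On recent JAX versions a similar replication is available through a public helper; a hedged alternative sketch (not the code above):

# jax.device_put_replicated copies an array onto each given device and stacks
# the copies along a new leading axis, which is the layout pmap expects.
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
replicated = jax.device_put_replicated(x, jax.local_devices())
print(replicated.shape)  # (num_local_devices, 4)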
Example #27
    def test_exact(self, model, width, strides, padding, phi, same_inputs,
                   filter_size, use_pooling, is_ntk, is_res):
        is_conv = 'conv' in model

        # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
        if is_conv:
            if xla_bridge.get_backend().platform == 'cpu':
                raise jtu.SkipTest(
                    'Not running CNN models on CPU to save time.')

            if use_pooling and not same_inputs:
                raise jtu.SkipTest(
                    'Pooling layers for different inputs or for same '
                    'padding not implemented.')

            if (is_res and is_conv
                    and ((strides is not None and strides !=
                          (1, 1)) or (padding == 'VALID' and filter_size !=
                                      (1, 1)))):
                raise jtu.SkipTest(
                    'Different paths in a residual model need to return'
                    ' outputs of the same shape.')
        elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0]
              or strides != STRIDES[0] or use_pooling):
            raise jtu.SkipTest('FC models do not have these parameters.')

        W_std, b_std = 2.**0.5, 0.5**0.5

        key = random.PRNGKey(1)
        x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE)

        init_fun, apply_fun, ker_fun = _get_net(W_std, b_std, filter_size,
                                                is_conv, use_pooling, is_res,
                                                padding, phi, strides, width,
                                                is_ntk)

        if is_ntk:
            exact = ker_fun(x1, x2).ntk
            ker_fun_empirical = monte_carlo.get_ker_fun_monte_carlo(
                init_fun, apply_fun, False, True)
            empirical = ker_fun_empirical(x1, x2, key, N_SAMPLES).ntk
            empirical = np.reshape(empirical, exact.shape)
        else:
            exact = ker_fun(x1, x2, compute_ntk=False).nngp
            ker_fun_empirical = monte_carlo.get_ker_fun_monte_carlo(
                init_fun, apply_fun, True, False)
            empirical = ker_fun_empirical(x1, x2, key, N_SAMPLES).nngp

        utils.assert_close_matrices(self, empirical, exact, RTOL)
Example #28
def _dynamic_xla_call_impl(*args, jaxpr, num_consts):
  in_dim_vals, consts, args = split_list(args, [len(jaxpr.in_dim_binders), num_consts])
  dim_in_avals = [v.aval for v in jaxpr.in_dim_binders]
  c = xb.make_computation_builder("dxla_call")
  dim_params, params = _make_params(c, dim_in_avals, map(xla.abstractify, args))
  const_params = _xla_consts(c, consts)
  dim_outs, outs = djaxpr_subcomp(c, jaxpr, dim_params, const_params + params)
  out = xops.Tuple(c, [o for ops in dim_outs + outs for o in ops])
  compiled = xb.get_backend(None).compile(c.build(out))
  result_handlers = map(result_handler, [v.aval for v in jaxpr.outs])
  out_bufcounts = [v.aval._num_buffers for v in jaxpr.outs]
  partitioner = result_partitioner(jaxpr.in_dim_binders, in_dim_vals,
                                   jaxpr.out_dims, out_bufcounts)
  return execute_compiled(compiled, partitioner, result_handlers,
                          in_dim_vals, args)
Example #29
def stub_out_pmap(batch, count):
    # If we are using GPU or CPU, stub out pmap with vmap to simulate multi-core.
    if count > 1:

        class xla_bridge_stub(object):
            def device_count(self):
                return count

        platform = xla_bridge.get_backend().platform
        if platform == 'gpu' or platform == 'cpu':
            # TODO(romann): investigate why vmap is extremely slow in
            # `utils/monte_carlo_test.py`, `test_monte_carlo_vs_analytic`.
            # Example: http://sponge/e081c176-e77f-428c-846d-bafbfd86a46c
            batch.pmap = vmap
            batch.xla_bridge = xla_bridge_stub()
Example #30
    def _res_tf_to_jax(res_tf: TfVal, out_aval: core.AbstractValue):
        res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(
            res_tf, jax_dtype=out_aval.dtype)
        if isinstance(res_tf,
                      tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
            res_tf_platform = tf.DeviceSpec.from_string(
                res_tf.backing_device).device_type
            res_jax_platform = res_tf_platform.lower()
            if res_jax_platform in _DLPACK_PLATFORMS:
                res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
                return jax.dlpack.from_dlpack(
                    res_dlpack,
                    backend=xla_bridge.get_backend(res_jax_platform))

        return jnp.asarray(np.asarray(res_tf))