예제 #1
0
 def testNested(self):
     """psum over both axis names gives the same total for either nesting order.

     A 2x2 array of ones is mapped by two nested serial_pmap calls. Both the
     ('i' inner, 'j' outer) and ('j' inner, 'i' outer) orders are checked, and
     every entry is expected to equal the grand total, 4.
     """
     def sum_both_axes(v):
         return lax_parallel.psum(lax_parallel.psum(v, 'i'), 'j')

     ones = onp.ones((2, 2))
     # Run both nesting orders before checking either result.
     results = [serial_pmap(serial_pmap(sum_both_axes, inner), outer)(ones)
                for inner, outer in (('i', 'j'), ('j', 'i'))]
     want = 4 * onp.ones((2, 2))
     for got in results:
         self.assertAllClose(got, want, check_dtypes=False)
예제 #2
0
    def testNestedBasic(self):
        """Nested pmap with psum over 'i' then 'j' reduces over both mapped axes."""
        def reduce_and_tile(arr, ax):
            # Sum along `ax`, then repeat the sum back to the original extent.
            total = onp.sum(arr, ax, keepdims=True)
            return onp.repeat(total, arr.shape[ax], ax)

        inner = lambda v: psum(psum(v, 'i'), 'j')
        mapped = pmap(pmap(inner, 'i'), 'j')

        shape = (xla_bridge.device_count(), 1, 4)
        data = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

        result = mapped(data)
        reference = reduce_and_tile(reduce_and_tile(data, 0), 1)
        self.assertAllClose(result, reference, check_dtypes=False)
예제 #3
0
    def testBasic(self):
        """pmap of `x - psum(x, 'i')` subtracts the column total from every row."""
        def subtract_total(shard):
            return shard - psum(shard, 'i')

        mapped = pmap(subtract_total, axis_name='i')

        shape = (xla_bridge.device_count(), 4)
        x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
        expected = x - onp.sum(x, 0)

        self.assertAllClose(mapped(x), expected, check_dtypes=False)
예제 #4
0
 def mapped_update(i, opt_state, batch, rng):
     """Multi-device update step: compute local gradients, all-reduce, apply.

     Gradients of `loss_fun` are computed locally. Every leaf of the gradient
     pytree is then summed across the "batch" axis with psum. The summed
     gradients are passed to the optimizer's update function together with
     step `i` and `opt_state`, and that function's result is returned.
     """
     # Per the original note: assumes every tensor has a leading axis of
     # size num_devices.
     _unused, opt_update = optimizer(lr_fun)
     params = trax_opt.get_params(opt_state)
     local_grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)

     def all_reduce(g):
         return lax_parallel.psum(g, "batch")

     summed_grads = jax.tree_util.tree_map(all_reduce, local_grads)
     return opt_update(i, summed_grads, opt_state)
예제 #5
0
    def testSum(self):
        """papply of a sum over axis 0 should lower to psum and match onp.sum."""
        pfun, axis_name = papply(lambda x: np.sum(x, axis=0), 5)

        # The papply'd function's jaxpr must print identically to the jaxpr
        # of a direct psum over the same axis name.
        jaxpr = make_jaxpr(pfun)(onp.ones(3))
        expected_jaxpr = make_jaxpr(lambda x: lax_parallel.psum(x, axis_name))(
            onp.zeros((5, 3)))
        # Use assertEqual rather than a bare `assert`: bare asserts are
        # stripped under `python -O`, which would silently drop this check.
        self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

        arg = onp.arange(15.).reshape((5, 3))
        ans = serial_pmap(pfun, axis_name)(arg)[0]
        expected = onp.sum(arg, axis=0)
        self.assertAllClose(ans, expected, check_dtypes=False)
예제 #6
0
  def testLogSoftmax(self):
    return SkipTest("test doesn't pass yet")  # TODO(frostig)

    def fun(x):
      return x - np.log(np.sum(np.exp(x)))

    pfun, axis_name = _papply(fun, 5)

    jaxpr = make_jaxpr(pfun)(onp.zeros(5))
    expected_jaxpr = make_jaxpr(
        lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(onp.zeros(5))
    assert repr(jaxpr) == repr(expected_jaxpr)

    ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
    expected = fun(onp.arange(1., 5.))
    self.assertAllClose(ans, expected, check_dtypes=False)
예제 #7
0
    def testPsumMultiple(self):
        """psum over a tuple of axis names reduces across both nested pmap axes."""
        def reduce_and_tile(arr, ax):
            # Sum along `ax`, then repeat the sum back to the original extent.
            total = onp.sum(arr, ax, keepdims=True)
            return onp.repeat(total, arr.shape[ax], ax)

        mapped = pmap(pmap(lambda v: psum(v, ('i', 'j')), 'i'), 'j')

        n_devices = xla_bridge.device_count()
        pairs, leftover = divmod(n_devices, 2)
        # Use a real 2-wide inner axis only when the devices split evenly into
        # more than one pair; otherwise fall back to a size-1 inner axis.
        use_pairs = pairs > 1 and leftover == 0
        shape = (pairs, 2, 4) if use_pairs else (n_devices, 1, 4)
        data = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

        result = mapped(data)
        reference = reduce_and_tile(reduce_and_tile(data, 0), 1)
        self.assertAllClose(result, reference, check_dtypes=False)
예제 #8
0
 def testLogSoftmax(self):
     """log-softmax using psum under serial_pmap matches the plain NumPy formula."""
     def log_softmax(v):
         normalizer = np.log(lax_parallel.psum(np.exp(v), 'i'))
         return v - normalizer

     x = onp.log(onp.arange(1., 10., dtype=onp.float32))
     result = serial_pmap(log_softmax, axis_name='i')(x)
     reference = x - onp.log(onp.sum(onp.exp(x)))
     self.assertAllClose(result, reference, check_dtypes=False)
예제 #9
0
 def testReduceSum(self):
     """psum under serial_pmap turns each of four ones into the total, 4."""
     reduced = serial_pmap(lambda v: lax_parallel.psum(v, 'i'), axis_name='i')
     self.assertAllClose(reduced(onp.ones(4)), onp.full(4, 4.),
                         check_dtypes=False)
예제 #10
0
 def f(x, y):
     """Return psum over axis 'i' of 5 * cos(x) * sin(y)."""
     scaled = 5. * np.cos(x)
     return psum(scaled * np.sin(y), 'i')