Esempio n. 1
0
 def testNested(self):
   """psum over two nested serial pmaps of a 2x2 ones array gives 4 everywhere.

   Both nesting orders of the 'i' and 'j' axes are checked.
   """
   def sum_both_axes(v):
     inner = lax_parallel.psum(v, 'i')
     return lax_parallel.psum(inner, 'j')

   ones = onp.ones((2, 2))
   # Build the result for each nesting order first, then compare both.
   results = [
       _serial_pmap(_serial_pmap(sum_both_axes, inner), outer)(ones)
       for inner, outer in (('i', 'j'), ('j', 'i'))
   ]
   want = 4 * onp.ones((2, 2))
   for got in results:
     self.assertAllClose(got, want, check_dtypes=False)
Esempio n. 2
0
  def testAdd(self):
    """papply of np.add with axis size 2 matches a plain elementwise add."""
    operand = onp.array([[1, 2, 3], [4, 5, 6]])
    reference = operand + operand

    parallel_add, name = _papply(np.add, 2)
    result = _serial_pmap(parallel_add, name)(operand, operand)
    self.assertAllClose(result, reference, check_dtypes=True)
Esempio n. 3
0
  def testTransposeAndAddRank3(self):
    """papply of x + x.T on a (2, 2, 2) float32 array matches numpy."""

    def add_transpose(a):
      return a + a.T

    cube = onp.arange(8., dtype=onp.float32).reshape((2, 2, 2))
    reference = cube + cube.T

    mapped_fn, mapped_axis = _papply(add_transpose, 2)
    result = _serial_pmap(mapped_fn, mapped_axis)(cube)
    self.assertAllClose(result, reference, check_dtypes=False)
Esempio n. 4
0
  def testAddBroadcasting(self):
    """papply of adding a scalar constant (x + 3) should match numpy.

    Currently skipped: the test does not pass yet.
    """
    # Raise (not return) SkipTest so the runner reports a skip; returning the
    # exception object would make the test silently count as passed.
    raise SkipTest("test doesn't pass yet")  # TODO(frostig)

    def fun(x):
      return x + 3

    x = onp.array([[1, 2], [3, 4]])
    expected = x + 3

    pfun, axis_name = _papply(fun, 2)
    ans = _serial_pmap(pfun, axis_name)(x)
    self.assertAllClose(ans, expected, check_dtypes=True)
Esempio n. 5
0
  def testMax(self):
    """papply of a max over axis 0 lowers to pmax and matches onp.max."""
    pfun, axis_name = _papply(lambda x: np.max(x, axis=0), 5)

    # Traced on one (3,)-shaped slice, the papplied function should be a
    # single pmax over the mapped axis.
    jaxpr = make_jaxpr(pfun)(onp.ones(3))
    expected_jaxpr = make_jaxpr(
        lambda x: lax_parallel.pmax(x, axis_name))(onp.zeros((5, 3)))
    # assertEqual (unlike a bare assert) is not stripped under `python -O`
    # and reports both reprs on failure.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    arg = onp.arange(15.).reshape((5, 3))
    # Take the result for the first mapped index.
    ans = _serial_pmap(pfun, axis_name)(arg)[0]
    expected = onp.max(arg, axis=0)
    self.assertAllClose(ans, expected, check_dtypes=False)
Esempio n. 6
0
  def testTransposeWithOddPermutation(self):
    """papply of transpose with axes (2, 0, 1) matches numpy on cubic arrays."""

    def fun(x):
      return np.transpose(x, (2, 0, 1))

    xs = [
        onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2)),
        onp.reshape(onp.arange(27., dtype=onp.float32), (3, 3, 3)),
    ]
    for x in xs:
      # Compute the reference with plain numpy, so it is independent of the
      # jax.numpy transpose being papplied.
      expected = onp.transpose(x, (2, 0, 1))
      pfun, axis_name = _papply(fun, x.shape[0])
      ans = _serial_pmap(pfun, axis_name)(x)
      self.assertAllClose(ans, expected, check_dtypes=False)
Esempio n. 7
0
  def testTranspose(self):
    """papply of a 2-D transpose matches x.T for square 2x2 and 3x3 inputs."""

    def transpose(m):
      return m.T

    matrices = [
        onp.arange(float(side * side), dtype=onp.float32).reshape((side, side))
        for side in (2, 3)
    ]
    for mat in matrices:
      reference = mat.T
      mapped_fn, mapped_axis = _papply(transpose, mat.shape[0])
      result = _serial_pmap(mapped_fn, mapped_axis)(mat)
      self.assertAllClose(result, reference, check_dtypes=False)
Esempio n. 8
0
  def testDot(self):
    """papply of lax.dot should match an unmapped dot for several in_axes.

    Currently skipped: the test does not pass yet.
    """
    # Raise (not return) SkipTest so the runner reports a skip; returning the
    # exception object would make the test silently count as passed.
    raise SkipTest("test doesn't pass yet")  # TODO(frostig)

    def fun(x, y):
      return lax.dot(x, y)

    xs = [
        onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)),
        onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)),
    ]
    # (1, 0) was listed in a trailing comment but is not exercised.
    # TODO: confirm whether it should be added to the combinations.
    in_axes_combos = [(0, 0), (0, 1)]
    for in_axes in in_axes_combos:
      for x in xs:
        expected = fun(x, x)
        pfun, axis_name = _papply(fun, x.shape[0], in_axes=in_axes)
        ans = _serial_pmap(pfun, axis_name)(x, x)
        self.assertAllClose(ans, expected, check_dtypes=False)
Esempio n. 9
0
  def testLogSoftmax(self):
    """papply of log-softmax should lower the normalizer sum to a psum.

    Currently skipped: the test does not pass yet.
    """
    # Raise (not return) SkipTest so the runner reports a skip; returning the
    # exception object would make the test silently count as passed.
    raise SkipTest("test doesn't pass yet")  # TODO(frostig)

    def fun(x):
      return x - np.log(np.sum(np.exp(x)))

    pfun, axis_name = _papply(fun, 5)

    # NOTE(review): the papply axis size is 5, but the serial_pmap input below
    # (onp.arange(1., 5.)) has only 4 elements. pfun also looks like it is
    # traced on a full length-5 array rather than a single slice. Confirm
    # these shapes when re-enabling this test.
    jaxpr = make_jaxpr(pfun)(onp.zeros(5))
    expected_jaxpr = make_jaxpr(
        lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(onp.zeros(5))
    # assertEqual (unlike a bare assert) is not stripped under `python -O`
    # and reports both reprs on failure.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
    expected = fun(onp.arange(1., 5.))
    self.assertAllClose(ans, expected, check_dtypes=False)
Esempio n. 10
0
  def testSelect(self):
    """papply of lax.select with only the true-branch operand mapped.

    The predicate and false-branch operands are unmapped (in_axes None). The
    papplied jaxpr is expected to split them with psplit_like so they match
    the per-index slice of the mapped operand.
    """
    pfun, axis_name = _papply(lax.select, 5,
                              in_axes=(None, 0, None))

    p = onp.arange(15).reshape((5, 3)) % 4 == 1
    t = onp.ones((5, 3))
    f = onp.zeros((5, 3))
    jaxpr = make_jaxpr(pfun)(p, t[0], f)

    def expected_spmd(p, t, f):
      return lax.select(
          lax_parallel.psplit_like(p, t, axis_name),
          t,
          lax_parallel.psplit_like(f, t, axis_name))

    expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
    # assertEqual (unlike a bare assert) is not stripped under `python -O`
    # and reports both reprs on failure.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    ans = _serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
    expected = lax.select(p, t, f)
    self.assertAllClose(ans, expected, check_dtypes=True)
Esempio n. 11
0
 def testLogSoftmax(self):
   """A log-softmax whose normalizer is a psum over 'i' matches numpy."""
   def log_softmax(v):
     normalizer = np.log(lax_parallel.psum(np.exp(v), 'i'))
     return v - normalizer

   logits = onp.log(onp.arange(1., 10., dtype=onp.float32))
   result = _serial_pmap(log_softmax, axis_name='i')(logits)
   reference = logits - onp.log(onp.sum(onp.exp(logits)))
   self.assertAllClose(result, reference, check_dtypes=False)
Esempio n. 12
0
 def testPsplitLike(self):
   """psplit_like round-trips an array when mapped and stacked on axis 2.

   The first operand is unmapped; the second is mapped along axis 2.
   """
   def split(x, y):
     return lax_parallel.psplit_like(x, y, 'i')

   data = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
   mapped = _serial_pmap(split, axis_name='i', in_axes=(None, 2), out_axes=2)
   self.assertAllClose(mapped(data, data), data, check_dtypes=False)
Esempio n. 13
0
 def testReduceMax(self):
   """pmax over arange(4) yields the maximum, 3, at every mapped index."""
   def all_max(v):
     return lax_parallel.pmax(v, 'i')

   result = _serial_pmap(all_max, axis_name='i')(onp.arange(4))
   self.assertAllClose(result, 3 * onp.ones(4), check_dtypes=False)
Esempio n. 14
0
 def testConstantFunction(self):
   """Mapping a function that ignores its input yields the constant per index."""
   def always_three(_):
     return 3

   result = _serial_pmap(always_three, axis_name='i')(onp.ones(4))
   self.assertAllClose(result, 3 * onp.ones(4), check_dtypes=False)
Esempio n. 15
0
 def testPsplit(self):
   """psplit along axis 2, stacked back on axis 2, reproduces the input."""
   def split(x):
     return lax.psplit(x, 'i', 2)

   data = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
   mapped = _serial_pmap(split, axis_name='i', out_axes=2)
   self.assertAllClose(mapped(data), data, check_dtypes=False)
Esempio n. 16
0
 def testReduceSum(self):
   """psum over four ones yields 4 at every mapped index."""
   def all_sum(v):
     return lax.psum(v, 'i')

   result = _serial_pmap(all_sum, axis_name='i')(onp.ones(4))
   self.assertAllClose(result, 4 * onp.ones(4), check_dtypes=False)