Beispiel #1
0
    def testAdd(self):
        """papply of elementwise ``np.add`` reproduces ``x + x`` under soft_pmap."""
        operand = onp.array([[1, 2, 3], [4, 5, 6]])
        pfun, axis_name = _papply(np.add)
        mapped = soft_pmap(pfun, axis_name)
        self.assertAllClose(mapped(operand, operand), operand + operand,
                            check_dtypes=True)
Beispiel #2
0
    def testSelect(self):
        """papply of a ``lax.select`` that closes over a fixed mask and false branch."""
        mask = onp.arange(15).reshape((5, 3)) % 4 == 1
        on_false = onp.zeros((5, 3))

        def fun(t):
            # `mask` and `on_false` are closed over, so `t` is the only
            # argument the papplied function receives.
            return lax.select(mask, t, on_false)

        on_true = onp.ones((5, 3))
        pfun, axis_name = _papply(fun)
        ans = soft_pmap(pfun, axis_name)(on_true)
        self.assertAllClose(ans, fun(on_true), check_dtypes=True)
Beispiel #3
0
  def testTransposeAndAddRank3(self):
    """papply of ``x + x.T`` on a 2x2x2 float32 array (papply size 2 == x.shape[0])."""
    x = onp.arange(8., dtype=onp.float32).reshape((2, 2, 2))

    def fun(x):
      return x + x.T

    pfun, axis_name = _papply(fun, 2)
    ans = _serial_pmap(pfun, axis_name)(x)
    self.assertAllClose(ans, x + x.T, check_dtypes=False)
Beispiel #4
0
    def testMax(self):
        """papply of a max-reduction over axis 0 lowers to ``lax.pmax``.

        Checks both the traced jaxpr (against a hand-written ``pmax``) and the
        numerical result under ``soft_pmap`` on a 5x3 input.
        """
        pfun, axis_name = _papply(lambda x: np.max(x, axis=0))

        # Trace both functions at the same per-shard input shape (3,) so the
        # comparison does not depend on shape details in the jaxpr repr.
        per_shard = onp.ones(3)
        jaxpr = make_jaxpr(pfun)(per_shard)
        expected_jaxpr = make_jaxpr(lambda x: lax.pmax(x, axis_name))(per_shard)
        # assertEqual, unlike a bare assert, is not stripped by `python -O`
        # and prints both reprs on failure.
        self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

        arg = onp.arange(15.).reshape((5, 3))
        # Only the first mapped output is compared against the reduction.
        ans = soft_pmap(pfun, axis_name)(arg)[0]
        expected = onp.max(arg, axis=0)
        self.assertAllClose(ans, expected, check_dtypes=False)
Beispiel #5
0
    def testDot(self):
        """papply of ``lax.dot`` with both operands split; currently skipped."""
        raise SkipTest("known failure")  # TODO(frostig)
        mat = onp.arange(4., dtype=onp.float32).reshape((2, 2))

        def fun(x, y):
            return lax.dot(x, y)

        expected = fun(mat, mat)
        pfun, axis_name = _papply(fun)
        ans = self.dedup(soft_pmap(pfun, axis_name)(mat, mat), expected.ndim)
        self.assertAllClose(ans, expected, check_dtypes=False)
Beispiel #6
0
    def testAddBroadcasting(self):
        """papply of adding a Python scalar to a split array; currently skipped."""
        raise SkipTest("test doesn't pass yet")  # TODO(frostig)

        def fun(x):
            return x + 3

        x = onp.array([[1, 2], [3, 4]])
        pfun, axis_name = _papply(fun)
        ans = soft_pmap(pfun, axis_name)(x)
        self.assertAllClose(ans, x + 3, check_dtypes=True)
Beispiel #7
0
  def testSum(self):
    """papply of a sum over axis 0 lowers to ``psum``, checked two ways.

    The traced jaxpr is compared against a hand-written ``psum``, and the
    numerical result is checked under ``_serial_pmap`` on a 5x3 input.
    """
    pfun, axis_name = _papply(lambda x: np.sum(x, axis=0), 5)

    # Trace both functions at the same per-shard input shape (3,) so the
    # comparison does not depend on shape details in the jaxpr repr.
    per_shard = onp.ones(3)
    jaxpr = make_jaxpr(pfun)(per_shard)
    expected_jaxpr = make_jaxpr(
        lambda x: lax_parallel.psum(x, axis_name))(per_shard)
    # assertEqual, unlike a bare assert, is not stripped by `python -O`
    # and prints both reprs on failure.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    arg = onp.arange(15.).reshape((5, 3))
    # Only the first mapped output is compared against the reduction.
    ans = _serial_pmap(pfun, axis_name)(arg)[0]
    expected = onp.sum(arg, axis=0)
    self.assertAllClose(ans, expected, check_dtypes=False)
Beispiel #8
0
    def testDotGeneral(self, matching, coloring, split):
        """papply of ``lax.dot_general`` over parameterized dimension numbers.

        Args:
          matching: matching[i] is the rhs dimension paired with lhs dimension i.
          coloring: coloring[i] marks pair i as batch (0), contracting (1), or
            neither (any other value).
          split: 0 passes only the lhs as an argument, 1 only the rhs, 2 both.

        Combinations that papply cannot handle (NotImplementedError or
        TypeError while tracing or evaluating) are reported as skips.
        """
        BATCH, CONTRACT, _ = range(3)
        SPLIT_LHS, SPLIT_RHS, SPLIT_BOTH = range(3)

        x = onp.arange(8.).reshape((2, 2, 2))
        y = onp.arange(8.).reshape((2, 2, 2)) + 4.

        def dims_with(color):
            # [(lhs dims...), (rhs dims...)] for pairs of `color`, or a pair
            # of empty tuples when no dimension has that color.
            pairs = [(i, matching[i]) for i in range(3) if coloring[i] == color]
            return list(zip(*pairs)) or [(), ()]

        dimension_numbers = [dims_with(CONTRACT), dims_with(BATCH)]

        def f(x, y):
            return lax.dot_general(x, y, dimension_numbers)

        if split == SPLIT_LHS:
            fun = lambda x: f(x, y)
        elif split == SPLIT_RHS:
            fun = lambda y: f(x, y)
        else:
            fun = f
        # Every mode except SPLIT_BOTH feeds the single argument `x`.
        args = (x, y) if split == SPLIT_BOTH else (x,)

        try:
            expected = fun(*args)
            pfun, axis_name = _papply(fun)
            ans = soft_pmap(pfun, axis_name)(*args)
        except (NotImplementedError, TypeError) as e:
            raise SkipTest(str(e)) from e

        ans = self.dedup(ans, expected.ndim)
        self.assertAllClose(ans, expected, check_dtypes=False)
Beispiel #9
0
  def testTransposeWithOddPermutation(self):
    """papply of the axis permutation (2, 0, 1) on float32 cubes of side 2 and 3."""
    perm = (2, 0, 1)

    def fun(x):
      return np.transpose(x, perm)

    for n in (2, 3):
      x = onp.arange(float(n ** 3), dtype=onp.float32).reshape((n, n, n))
      pfun, axis_name = _papply(fun, n)
      ans = _serial_pmap(pfun, axis_name)(x)
      self.assertAllClose(ans, fun(x), check_dtypes=False)
Beispiel #10
0
  def testTranspose(self):
    """papply of a 2-D transpose on square float32 matrices of side 2 and 3."""

    def fun(x):
      return x.T

    for n in (2, 3):
      x = onp.arange(float(n * n), dtype=onp.float32).reshape((n, n))
      pfun, axis_name = _papply(fun, n)
      ans = _serial_pmap(pfun, axis_name)(x)
      self.assertAllClose(ans, x.T, check_dtypes=False)
Beispiel #11
0
    def testLogSoftmax(self):
        """papply of log-softmax should turn the inner sum into ``psum``; skipped.

        Compares the traced jaxpr against a hand-written version using
        ``lax.psum``, then checks numerics under ``soft_pmap``.
        """
        raise SkipTest("test doesn't pass yet")  # TODO(frostig)

        def fun(x):
            return x - np.log(np.sum(np.exp(x)))

        pfun, axis_name = _papply(fun)

        jaxpr = make_jaxpr(pfun)(onp.zeros(5))
        expected_jaxpr = make_jaxpr(
            lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
        # assertEqual, unlike a bare assert, is not stripped by `python -O`
        # and prints both reprs on failure.
        self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

        xs = onp.arange(1., 5.)
        ans = soft_pmap(pfun, axis_name)(xs)
        expected = fun(xs)
        self.assertAllClose(ans, expected, check_dtypes=False)
Beispiel #12
0
  def testDot(self):
    """papply of ``lax.dot`` across several ``in_axes`` layouts; skipped for now."""
    # Must be `raise`, not `return`: returning the SkipTest instance made the
    # test silently pass instead of being reported as skipped.
    raise SkipTest("test doesn't pass yet")  # TODO(frostig)

    def fun(x, y):
      return lax.dot(x, y)
    xs = [
        onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)),
        onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)),
    ]
    # NOTE(review): the (1, 0) layout was left out (commented out originally);
    # presumably unsupported — TODO confirm and add it back.
    in_axes_combos = [(0, 0), (0, 1)]
    for in_axes in in_axes_combos:
      for x in xs:
        expected = fun(x, x)
        pfun, axis_name = _papply(fun, x.shape[0], in_axes=in_axes)
        ans = _serial_pmap(pfun, axis_name)(x, x)
        self.assertAllClose(ans, expected, check_dtypes=False)
Beispiel #13
0
  def testSelect(self):
    """papply of ``lax.select`` with only the second operand mapped.

    The papplied jaxpr must match an SPMD program that applies
    ``lax_parallel.psplit_like`` to ``p`` and ``f`` (against ``t``), and the
    ``_serial_pmap`` result must match an unsplit ``lax.select``.
    """
    # Only `t` is mapped over axis 0; the None entries are presumably
    # passed unsplit to every shard.
    in_axes = (None, 0, None)
    pfun, axis_name = _papply(lax.select, 5, in_axes=in_axes)

    p = onp.arange(15).reshape((5, 3)) % 4 == 1
    t = onp.ones((5, 3))
    f = onp.zeros((5, 3))
    jaxpr = make_jaxpr(pfun)(p, t[0], f)

    def expected_spmd(p, t, f):
      return lax.select(
          lax_parallel.psplit_like(p, t, axis_name),
          t,
          lax_parallel.psplit_like(f, t, axis_name))

    expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
    # assertEqual, unlike a bare assert, is not stripped by `python -O`
    # and prints both reprs on failure.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    ans = _serial_pmap(pfun, axis_name, in_axes=in_axes)(p, t, f)
    expected = lax.select(p, t, f)
    self.assertAllClose(ans, expected, check_dtypes=True)
Beispiel #14
0
 def testMap(self):
     """Calling a papplied elementwise ``np.sin`` directly matches ``onp.sin``."""
     xs = onp.arange(3.)
     pfun, _ = _papply(np.sin)
     self.assertAllClose(pfun(xs), onp.sin(xs), check_dtypes=False)
Beispiel #15
0
 def testIdentity(self):
     """Calling a papplied identity function directly returns its input's values."""
     arg = onp.arange(3)
     pfun, _ = _papply(lambda x: x)
     self.assertAllClose(pfun(arg), onp.arange(3), check_dtypes=False)
Beispiel #16
0
 def testMakeJaxprPapplyComposition(self):
     """``make_jaxpr`` over a papplied ``np.where`` should trace without error."""
     raise SkipTest(  # TODO(mattjj)
         "fails because select's papply rule calls an SPMD primitive")
     cond = on_false = onp.ones(3)
     pfun, _ = _papply(lambda a: np.where(cond, a, on_false))
     make_jaxpr(pfun)(onp.ones(3))  # doesn't crash