Example #1
0
  def testAdd(self):
    """Check that a papply-transformed `np.add` reproduces `x + x`.

    The transformed function is mapped with `serial_pmap` under the axis
    name returned by `papply`; values and dtypes are both compared.
    """
    operand = onp.array([[1, 2, 3], [4, 5, 6]])
    parallel_add, name = papply(np.add, 2)

    result = serial_pmap(parallel_add, name)(operand, operand)
    self.assertAllClose(result, operand + operand, check_dtypes=True)
Example #2
0
    def testSum(self):
        """Check that papply of `np.sum` traces to an explicit `psum`.

        Two things are verified: the jaxpr traced from the transformed
        function prints identically to the jaxpr of `psum` over the returned
        axis name, and mapping the transformed function with `pmap` agrees
        numerically with `onp.sum`.
        """
        parallel_sum, name = papply(np.sum)

        # Structural check: compare the printed jaxprs.
        probe = onp.zeros(5)
        traced = make_jaxpr(parallel_sum)(probe)
        traced_reference = make_jaxpr(lambda x: psum(x, name))(probe)
        assert repr(traced) == repr(traced_reference)

        # Numerical check against the serial reduction.
        data = onp.arange(3.)
        result = pmap(parallel_sum, name)(data)
        self.assertAllClose(result, onp.sum(data), check_dtypes=False)
Example #3
0
    def testAddBroadcasting(self):
        """Check papply on a function that adds the scalar 3 to its input.

        The transformed function is mapped with `pmap` over a 2x2 integer
        array and compared (dtypes included) with `x + 3`.
        """
        def add_three(v):
            return v + 3

        matrix = onp.array([[1, 2], [3, 4]])
        parallel_fn, name = papply(add_three)

        result = pmap(parallel_fn, name)(matrix)
        self.assertAllClose(result, matrix + 3, check_dtypes=True)
Example #4
0
    def testTransposeAndAddRank3(self):
        """Check papply on `x + x.T` for a float32 input of shape (2, 2, 2)."""
        def add_transpose(v):
            return v + v.T

        flat = onp.arange(8., dtype=onp.float32)
        cube = onp.reshape(flat, (2, 2, 2))
        parallel_fn, name = papply(add_transpose, 2)

        result = serial_pmap(parallel_fn, name)(cube)
        self.assertAllClose(result, cube + cube.T, check_dtypes=False)
Example #5
0
    def testSum(self):
        """Check that papply of a sum over axis 0 traces to an explicit `psum`.

        The jaxpr traced from the transformed function is compared, via its
        printed form, with the jaxpr of `lax_parallel.psum` over the returned
        axis name. The first entry of the `serial_pmap` result is then
        compared numerically with `onp.sum(..., axis=0)`.
        """
        parallel_sum, name = papply(lambda x: np.sum(x, axis=0), 5)

        # Structural check: compare the printed jaxprs.
        traced = make_jaxpr(parallel_sum)(onp.ones(3))
        traced_reference = make_jaxpr(
            lambda x: lax_parallel.psum(x, name))(onp.zeros((5, 3)))
        assert repr(traced) == repr(traced_reference)

        # Numerical check on a 5x3 input.
        data = onp.reshape(onp.arange(15.), (5, 3))
        mapped = serial_pmap(parallel_sum, name)(data)
        self.assertAllClose(mapped[0], onp.sum(data, axis=0),
                            check_dtypes=False)
Example #6
0
    def testAddBroadcasting(self):
        return SkipTest("test doesn't pass yet")  # TODO(frostig)

        def fun(x):
            return x + 3

        x = onp.array([[1, 2], [3, 4]])
        expected = x + 3

        pfun, axis_name = papply(fun, 2)
        ans = serial_pmap(pfun, axis_name)(x)
        self.assertAllClose(ans, expected, check_dtypes=True)
Example #7
0
    def testTransposeWithOddPermutation(self):
        """Check papply on a rank-3 transpose with permutation (2, 0, 1).

        Runs on float32 inputs of shape (2, 2, 2) and (3, 3, 3). The size
        passed to `papply` is the length of each input's leading axis.
        """
        perm = (2, 0, 1)

        def fun(x):
            return np.transpose(x, perm)

        xs = [
            onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2)),
            onp.reshape(onp.arange(27., dtype=onp.float32), (3, 3, 3)),
        ]
        for x in xs:
            # Build the reference with `onp`, independently of the
            # `np.transpose` call that the function under test makes.
            expected = onp.transpose(x, perm)
            pfun, axis_name = papply(fun, x.shape[0])
            ans = serial_pmap(pfun, axis_name)(x)
            self.assertAllClose(ans, expected, check_dtypes=False)
Example #8
0
    def testTranspose(self):
        """Check papply on `x.T` for 2x2 and 3x3 float32 inputs.

        The size passed to `papply` is the length of each input's leading
        axis; results are compared with the array's own `.T`.
        """
        def transpose(v):
            return v.T

        # (arange stop, target shape) for each test matrix.
        cases = ((4., (2, 2)), (9., (3, 3)))
        for stop, shape in cases:
            matrix = onp.reshape(onp.arange(stop, dtype=onp.float32), shape)
            parallel_fn, name = papply(transpose, matrix.shape[0])
            result = serial_pmap(parallel_fn, name)(matrix)
            self.assertAllClose(result, matrix.T, check_dtypes=False)
Example #9
0
    def testLogSoftmax(self):
        """Check papply on `x - log(sum(exp(x)))`.

        The jaxpr traced from the transformed function must print identically
        to the same expression written with `lax.psum` in place of `np.sum`,
        and the `pmap` result must match the untransformed function.
        """
        def log_softmax(v):
            return v - np.log(np.sum(np.exp(v)))

        parallel_fn, name = papply(log_softmax)

        # Structural check: compare the printed jaxprs.
        probe = onp.zeros(5)
        traced = make_jaxpr(parallel_fn)(probe)
        traced_reference = make_jaxpr(
            lambda x: x - np.log(lax.psum(np.exp(x), name)))(probe)
        assert repr(traced) == repr(traced_reference)

        # Numerical check against the untransformed function.
        data = onp.arange(1., 5.)
        result = pmap(parallel_fn, name)(data)
        self.assertAllClose(result, log_softmax(data), check_dtypes=False)
Example #10
0
  def testDot(self):
    """Check papply on `lax.dot` for 2x2 and 3x3 float32 inputs.

    Each matrix is multiplied with itself under the `in_axes` settings
    (0, 0) and (0, 1), and the result is compared with the untransformed
    `lax.dot`.
    """
    def dot(a, b):
      return lax.dot(a, b)

    matrices = [
        onp.reshape(onp.arange(stop, dtype=onp.float32), shape)
        for stop, shape in ((4., (2, 2)), (9., (3, 3)))
    ]
    # The (1, 0) combination is not exercised here.
    for axes in [(0, 0), (0, 1)]:
      for m in matrices:
        parallel_fn, name = papply(dot, m.shape[0], in_axes=axes)
        result = serial_pmap(parallel_fn, name)(m, m)
        self.assertAllClose(result, dot(m, m), check_dtypes=False)
Example #11
0
    def testLogSoftmax(self):
        return SkipTest("test doesn't pass yet")  # TODO(frostig)

        def fun(x):
            return x - np.log(np.sum(np.exp(x)))

        pfun, axis_name = papply(fun, 5)

        jaxpr = make_jaxpr(pfun)(onp.zeros(5))
        expected_jaxpr = make_jaxpr(
            lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(
                onp.zeros(5))
        assert repr(jaxpr) == repr(expected_jaxpr)

        ans = serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
        expected = fun(onp.arange(1., 5.))
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #12
0
    def testSelect(self):
        """Check papply on `lax.select` with only the second operand mapped.

        The jaxpr traced from the transformed function must print identically
        to a hand-written version that applies `lax_parallel.psplit_like` to
        the predicate and the false-branch operand. The `serial_pmap` result
        is then compared (dtypes included) with plain `lax.select`.
        """
        axes = (None, 0, None)
        parallel_select, name = papply(lax.select, 5, in_axes=axes)

        pred = onp.arange(15).reshape((5, 3)) % 4 == 1
        on_true = onp.ones((5, 3))
        on_false = onp.zeros((5, 3))

        def expected_spmd(p, t, f):
            split_p = lax_parallel.psplit_like(p, t, name)
            split_f = lax_parallel.psplit_like(f, t, name)
            return lax.select(split_p, t, split_f)

        # Structural check: compare the printed jaxprs.
        traced = make_jaxpr(parallel_select)(pred, on_true[0], on_false)
        traced_reference = make_jaxpr(expected_spmd)(
            pred, on_true[0], on_false)
        assert repr(traced) == repr(traced_reference)

        # Numerical check against the untransformed select.
        mapped = serial_pmap(parallel_select, name, in_axes=axes)
        result = mapped(pred, on_true, on_false)
        self.assertAllClose(result, lax.select(pred, on_true, on_false),
                            check_dtypes=True)
Example #13
0
 def testMap(self):
   """Check that papply of `np.sin`, called directly, matches `onp.sin`.

   The axis name returned by `papply` is not used in this test.
   """
   parallel_sin, _ = papply(np.sin, 3)
   result = parallel_sin(onp.arange(3.))
   reference = onp.sin(onp.arange(3.))
   self.assertAllClose(result, reference, check_dtypes=False)
Example #14
0
 def testIdentity(self):
   """Check that papply of the identity function returns its input values.

   The axis name returned by `papply` is not used in this test.
   """
   parallel_identity, _ = papply(lambda x: x, 3)
   result = parallel_identity(onp.arange(3))
   reference = onp.arange(3)
   self.assertAllClose(result, reference, check_dtypes=False)