コード例 #1
0
ファイル: pmap_test.py プロジェクト: ziyadedher/jax
    def testUnmaterializedAxis(self):
        shape = (4, 8)
        spec = pxla.ShardingSpec(shards_per_axis=(4, 1),
                                 is_axis_materialized=(False, True),
                                 replication_factor=1)
        self.assertEqual(pxla.spec_to_indices(shape, spec), (0, 1, 2, 3))

        shape = (2, 2)
        spec = pxla.ShardingSpec(shards_per_axis=(1, 2),
                                 is_axis_materialized=(True, False),
                                 replication_factor=1)
        self.assertEqual(pxla.spec_to_indices(shape, spec),
                         ((slice(None), 0), (slice(None), 1)))
コード例 #2
0
ファイル: pmap_test.py プロジェクト: mattwescott/jax
 def testReplication(self):
   shape = (2, 8)
   spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                            is_axis_materialized=(False, True),
                            replication_factor=3)
   self.assertEqual(pxla.spec_to_indices(shape, spec),
                    (0, 0, 0, 1, 1, 1))
コード例 #3
0
ファイル: pmap_test.py プロジェクト: mattwescott/jax
 def testNoSharding(self):
   shape = (4,8)
   spec = pxla.ShardingSpec(shards_per_axis=(1, 1),
                            is_axis_materialized=(True, True),
                            replication_factor=1)
   self.assertEqual(pxla.spec_to_indices(shape, spec),
                    (slice(None),))
コード例 #4
0
ファイル: pmap_test.py プロジェクト: mattwescott/jax
 def testUnshardedAxis(self):
   shape = (4, 8)
   spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                            is_axis_materialized=(True, True),
                            replication_factor=1)
   self.assertEqual(pxla.spec_to_indices(shape, spec),
                    (slice(0,2), (slice(2,4))))
コード例 #5
0
ファイル: xmap_test.py プロジェクト: rrstal/jax
 def testOneLogicalTwoMeshAxesSharding(self):
   def f(v):
     return v * 4
   fxy = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
              axis_resources={'a': ('x', 'y')})
   fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
              axis_resources={'a': ('y', 'x')})
   vshape = (4, 5)
   v = jnp.arange(np.prod(vshape)).reshape(vshape)
   zxy = fxy(v)
   self.assertEqual(
       zxy.sharding_spec,
       pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                         (pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
   zyx = fyx(v)
   self.assertEqual(
       zyx.sharding_spec,
       pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                         (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
コード例 #6
0
ファイル: pmap_test.py プロジェクト: ziyadedher/jax
    def testReshardInput(self):
        if xla_bridge.device_count() < 6:
            raise SkipTest("testReshardInput requires 6 devices")
        # Manually construct a ShardedDeviceArray with the wrong sharding for the
        # subsequent pmap
        shard_shape = (3, 2)
        shard = np.arange(np.prod(shard_shape)).reshape(shard_shape)
        bufs = [xla.device_put(shard, d) for d in xla_bridge.devices()[:4]]
        aval = ShapedArray((6, 4), shard.dtype)
        sharding_spec = pxla.ShardingSpec(shards_per_axis=(2, 2),
                                          is_axis_materialized=(True, True),
                                          replication_factor=2)
        arr = pxla.ShardedDeviceArray(aval, sharding_spec, bufs)

        r = pmap(lambda x: x + 1)(arr)
        self.assertAllClose(r, arr + 1, check_dtypes=True)
        self.assertEqual(len(r.device_buffers), 6)