Esempio n. 1
0
    def testVmapOfPmapTuple(self):
        """vmap over a pmapped identity with mixed in_axes/out_axes.

        Each input is batched along a different axis (including an unmapped
        `None` entry inside a nested tuple); the outputs must come back laid
        out exactly as `out_axes` requests.
        """
        device_count = xla_bridge.device_count()
        f0 = lambda *x: x
        f1 = pmap(f0, axis_name='i')

        ax = onp.random.randn(device_count, 2, 50, 60)
        ay = onp.random.randn(device_count, 30, 2)
        az1 = onp.random.randn(device_count, 20)
        az2 = onp.random.randn(2, device_count, 20)

        bx, by, bz = vmap(f1, in_axes=(1, 2, (None, 0)),
                          out_axes=(1, 2, 0))(ax, ay, (az1, az2))

        # Mapped in and out along the same axis, so these round-trip.
        self.assertAllClose(ax, bx, check_dtypes=False)
        self.assertAllClose(ay, by, check_dtypes=False)

        bz1, bz2 = bz
        # az1 was unmapped (in_axes=None) but is returned with out_axes=0,
        # so it is broadcast along a new leading batch axis of size 2.
        expected_bz1 = onp.broadcast_to(az1, (2, ) + az1.shape)
        self.assertAllClose(expected_bz1, bz1, check_dtypes=False)
        # az2 was mapped and returned along axis 0, so it must round-trip.
        self.assertAllClose(az2, bz2, check_dtypes=False)
Esempio n. 2
0
    def testBadAxisSizeError(self):
        """pmap must reject inputs whose leading axis size differs from the
        number of local devices it was given, both smaller and larger."""
        if xla_bridge.device_count() == 1:
            raise SkipTest("this test requires multiple devices")

        f = pmap(lambda x: lax.psum(x, 'i'),
                 axis_name='i',
                 devices=xla_bridge.devices())
        # Counts are matched with `\d+` so the patterns still hold with 10 or
        # more devices; literal periods in the message are escaped.
        with self.assertRaisesRegex(
                ValueError,
                r"Leading axis size of input to pmapped function must "
                r"equal the number of local devices passed to pmap\. Got axis_size=1, "
                r"num_local_devices=\d+\."):
            f(np.ones(1))

        with self.assertRaisesRegex(
                ValueError,
                r"Leading axis size of input to pmapped function must "
                r"equal the number of local devices passed to pmap\. Got axis_size=\d+, "
                r"num_local_devices=\d+\."):
            f(np.ones(xla_bridge.device_count() + 1))
Esempio n. 3
0
 def testCollectivesWithTreesOfDifferentDtypes(self):
     """pmax/pmin/psum/pmean over a dict of mixed float32/int32 leaves must
     match the corresponding NumPy reduction over axis 0, broadcast back to
     each leaf's shape."""
     n = len(jax.devices())
     # leaf name -> (start multiplier, dtype); every leaf is an n x n block of
     # consecutive values in [start * n * n, (start + 1) * n * n).
     leaf_specs = {
         'a': (1, onp.float32),
         'b': (2, onp.int32),
         'c': (4, onp.float32),
         'd': (6, onp.int32),
     }
     x = {
         name: onp.arange(start * n * n, (start + 1) * n * n,
                          dtype=dtype).reshape([n, n])
         for name, (start, dtype) in leaf_specs.items()
     }

     def pmapped(collective):
         # Apply `collective` over the pmapped axis 'i' to every leaf.
         return pmap(lambda leaf: collective(leaf, 'i'), 'i')

     def reference(reduction):
         # Reduce each leaf over axis 0, then broadcast to the leaf's shape.
         return partial(
             tree_util.tree_map,
             lambda leaf: onp.broadcast_to(reduction(leaf, 0), leaf.shape))

     check = partial(tree_util.tree_multimap,
                     partial(self.assertAllClose, check_dtypes=False))
     pairs = ((lax.pmax, onp.max), (lax.pmin, onp.min),
              (lax.psum, onp.sum), (lax.pmean, onp.mean))
     for collective, reduction in pairs:
         check(pmapped(collective)(x), reference(reduction)(x))
Esempio n. 4
0
        def f_pmapped(x, *args, **kwargs):
            """Apply `f` under `jit` or `pmap`, compiling once per static signature.

            Positional arguments for which `_is_np_ndarray` is true become
            traced inputs of the compiled function, after being mapped with
            the enclosing-scope `broadcast` helper. Every other positional
            argument, and all keyword arguments, are closed over and form part
            of the cache key, so they must be hashable. `x` is passed through
            unchanged (it is not broadcast).
            """
            args_np, args_np_idxs = [], []
            args_other = {}

            # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
            # https://github.com/google/jax/issues/912
            # Filter out `np.ndarray`s from other arguments.
            for i, arg in enumerate(args):
                if _is_np_ndarray(arg):
                    args_np.append(arg)
                    args_np_idxs.append(i)
                else:
                    args_other[i] = arg

            # Check cache before jitting. The positions of the ndarray
            # arguments are part of the key: the cached closure below bakes in
            # `args_np_idxs`, so calls with the same static arguments but a
            # different number or placement of arrays must not share an entry
            # (otherwise `zip` in `_f` would silently drop arrays).
            _key = (key + (tuple(args_np_idxs),) + tuple(args_other.items()) +
                    tuple(kwargs.items()))
            if _key in cache:
                _f = cache[_key]
            else:
                # Define a `np.ndarray`-only function as a closure over other arguments.
                def _f(_x, *_args_np):
                    # Merge args back into their original positional order.
                    _args_np = {
                        i: _arg_np
                        for i, _arg_np in zip(args_np_idxs, _args_np)
                    }
                    _args = _merge_dicts(_args_np, args_other)
                    _args = tuple(v for k, v in sorted(_args.items()))
                    return f(_x, *_args, **kwargs)

                # device_count == 0 selects a plain single-device `jit`.
                _f = jit(_f) if device_count == 0 else pmap(_f)
                cache[_key] = _f

            # Broadcast `np.ndarray` arguments and apply the new function to them.
            args_np = tree_map(broadcast, args_np)
            return _f(x, *args_np)
Esempio n. 5
0
    def testNestedPmapConstant(self):
        """Nested pmap of a constant function yields the constant, laid out
        on the same devices as any other doubly-pmapped array."""
        n_dev = xla_bridge.device_count()
        if n_dev == 1:
            raise SkipTest("this test requires multiple devices")

        shape = (2, n_dev // 2, 3)
        x = np.arange(prod(shape)).reshape(shape)
        expected = 3 * onp.ones(shape[:2])

        def devices_of(arr):
            # Device of each per-shard buffer, in buffer order.
            return [buf.device() for buf in arr.device_buffers]

        ans = pmap(pmap(lambda _: 3))(x)
        self.assertAllClose(ans, expected, check_dtypes=False)

        # Test that 'ans' was properly replicated across devices.
        expected_sharded = pmap(pmap(lambda v: v))(expected)
        self.assertEqual(devices_of(ans), devices_of(expected_sharded))

        x_sharded, ans = pmap(pmap(lambda v: (v, 3)))(x)
        self.assertAllClose(ans, expected, check_dtypes=False)
        self.assertEqual(devices_of(ans), devices_of(x_sharded))
Esempio n. 6
0
File: api_test.py Progetto: yyht/jax
 def f(x):
     """Apply `np.mean` to each slice of `x` along its leading (pmapped) axis."""
     per_slice_mean = api.pmap(np.mean)
     return per_slice_mean(x)
Esempio n. 7
0
 def g(z):
   """Multiply each leading-axis slice of `z` by the enclosing-scope `y`, via pmap."""
   times_y = pmap(lambda elem: elem * y)
   return times_y(z)
Esempio n. 8
0
 def testShardedDeviceArrayBlockUntilReady(self):
   """Calling block_until_ready() on a pmap result must not raise."""
   identity = pmap(lambda v: v)
   sharded = identity(onp.arange(xla_bridge.device_count()))
   sharded.block_until_ready()  # doesn't crash
Esempio n. 9
0
 def distributed_matrix_vector(x, y):
   """Matrix vector multiply. First batch it and then row by row.

   The rows of `x` are split into one batch per device (so `x.shape[0]`
   must be divisible by `jax.device_count()`). Each device then applies
   `vv(row, y)` to its rows sequentially with `lax.map`, and the per-device
   results are concatenated back along the first axis.
   """
   n_dev = jax.device_count()
   per_device = x.reshape((n_dev, -1) + tuple(x.shape[1:]))

   def rows_times_y(block):
     # Sequential (lax.map) loop over the rows held by one device.
     return lax.map(lambda row: vv(row, y), block)

   out = pmap(rows_times_y)(per_device)
   n_lead, n_inner = out.shape[0], out.shape[1]
   return out.reshape(n_lead * n_inner, *out.shape[2:])
Esempio n. 10
0
 def pmvm(a, b):
   """Parallel product `a @ b` via pmap, splitting the rows of `a`.

   The rows of `a` are split evenly into `nrep` groups (`nrep` is read from
   the enclosing module; `a.shape[0]` must be divisible by it). Each group is
   multiplied by `b` in its own pmap slice, and the results are stitched back
   into an array of shape `(a.shape[0],) + b.shape[1:]`.
   """
   # One output row per row of `a`. Reshaping to `b.shape` would only be
   # correct when `a` is square.
   out_shape = (a.shape[0],) + tuple(b.shape[1:])
   a = a.reshape((nrep, -1, a.shape[1]))
   func = pmap(lambda z: np.dot(z, b))
   return func(a).reshape(out_shape)
Esempio n. 11
0
 def testMismatchedAxisSizes(self):
   """pmap of a binary function must reject arguments whose leading axes differ."""
   num_devices = xla_bridge.device_count()
   add = pmap(lambda a, b: a + b)

   def call_with_mismatch():
     # First argument spans every device; the second is one element short.
     return add(onp.random.randn(num_devices),
                onp.random.randn(num_devices - 1))

   jtu.check_raises_regexp(
       call_with_mismatch, ValueError,
       "Axis size .* does not match leading dimension of shape .*")
Esempio n. 12
0
 def testReduceSum(self):
     """psum over the pmapped axis gives every device the total."""
     # pmap requires the leading axis to equal the local device count, so
     # size the input to it instead of hard-coding a count.
     import jax
     n = jax.local_device_count()
     f = lambda x: lax.psum(x, 'i')
     ans = pmap(f, axis_name='i')(onp.ones(n))
     expected = n * onp.ones(n)
     self.assertAllClose(ans, expected, check_dtypes=False)
Esempio n. 13
0
 def testConstantFunction(self):
     """A pmapped constant function returns the constant on every device."""
     # pmap requires the leading axis to equal the local device count, so
     # size the input to it instead of hard-coding a count.
     import jax
     n = jax.local_device_count()
     f = lambda x: 3
     ans = pmap(f, axis_name='i')(onp.ones(n))
     expected = 3 * onp.ones(n)
     self.assertAllClose(ans, expected, check_dtypes=False)
Esempio n. 14
0
def main(unused_argv):
    """Train a Wide ResNet on CIFAR-10 and log metrics to a CSV file.

    Every setting is read from `FLAGS`:
      * architecture: `N`, `K`;
      * loss: `losst`;
      * optimizer: `lr`, `momentum`, `std_wrn_sch`;
      * L2 strength and its optional automatic decay: `L2`, `L2_sch`,
        `L2dec`, `delay`;
      * data pipeline: `augment`, `mix`, `bs`;
      * measurement cadence: `meas_step`;
      * checkpointing: `checkpointing`, `load_w`, `steps_from_load`.

    When `FLAGS.TPU` is set, JAX is pointed at the TPU driver at
    `$TPU_ADDR:8470`. At the end, a `done/<TPU_ADDR>` marker file containing
    the CSV path is written.

    Args:
      unused_argv: ignored; never read by this function.
    """
    from jax.api import grad, jit, vmap, pmap, device_put

    # The following is required to use TPU Driver as JAX's backend.
    if FLAGS.TPU:
        TPU_ADDR = os.environ['TPU_ADDR']
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + TPU_ADDR + ':8470'
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        # Non-TPU runs are forced onto a single device.
        ndevices = 1

    # Every pmap below shares the collective axis name 'i' (used by psum).
    pmap = partial(pmap, axis_name='i')

    # Setup some experiment parameters.
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    # In "physical" mode the epoch budget FLAGS.epochs is divided by lr
    # (times L2 when physicalL2 is also set).
    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
        if FLAGS.physicalL2:
            tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    batch_size = batch_size_per_device * ndevices
    # 50000 is presumably the CIFAR-10 training-set size.
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    # Filename from FLAGS.
    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)
    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    # exist_ok avoids a check-then-create race between concurrent jobs.
    os.makedirs(filedir, exist_ok=True)
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))

    # Load CIFAR10 data and create a minimal pipeline.
    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)

    # Create a Wide Resnet and replicate its parameters across the devices.
    initparams, f, _ = utils.WideResnetnt(N, K, k)

    # Loss and optimizer definitions.
    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    currL2 = FLAGS.L2
    # Current L2 coefficient, one copy per device.
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels,
                                                                       axis=1),
                     dtype=np.float32))

    # Define optimizer.
    if FLAGS.std_wrn_sch:
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    @pmap
    def update_step(step, state, batch_state, L2):
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        # Average gradients across devices.
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    # Initialization and loading.
    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: \
        np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)

    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)

    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    batch_xs, _ = pmap(batch_fn)(batch_state)

    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    idel0 = i0
    start = time.time()
    # Best train accuracy / lowest bare train loss seen so far (AutoL2);
    # bound on the first update once two measurements exist.
    maxacc, minloss = None, None

    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    # Start training loop.
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))

    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)

            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            # minloss/maxacc are always bound here: this branch needs at least
            # three measurements, and the elif below binds them at the second.
            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(
                    train_lm) > 2 and ((minloss <= train_lm[-1]
                                        and minloss <= train_lm[-2]) or
                                       (maxacc >= train_accuracy[-1]
                                        and maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period, decay if the loss or error have increased in the last two measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i

            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the minimum values.
                if maxacc is None:
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]
                else:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))

                start = time.time()
                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)

                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step

                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:

                os.makedirs('weights/', exist_ok=True)
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    # Keep only the first device's replica of each leaf.
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])

                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)

                utils.save_weights(filename, step0, saveparams, batch_state)

        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')

    if FLAGS.TPU:
        # The marker directory may not exist yet; open() would then fail.
        os.makedirs('done', exist_ok=True)
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
Esempio n. 15
0
def _parallel(ker_fun, device_count=-1):
    """Returns a function that computes a kernel in batches in parallel.

    When batching in parallel, the rows of `x1` are split over a set number of
    devices. The number of devices must be less than or equal to the number of
    physical devices. The dataset size |x1| must either be divisible by the
    device count, or be smaller than it. In the latter case only |x1| devices
    are used, one row each.

    Given two datasets x1 and x2, parallel splits the kernel calculation over
    devices such that each device computes a batch of rows of shape
    [|x1| / device_count, |x2|].

    Args:
      ker_fun: A function that computes a kernel between two datasets,
          ker_fun(x1, x2). Here x1 and x2 are `np.ndarray`s of floats of shape
          [n1,] + input_shape and [n2,] + input_shape. The kernel function
          should return a PyTree.
      device_count: Integer specifying the number of devices over which to
          split the data. If device_count = -1 (the default), the computation
          is parallelized over all available devices.

    Returns:
      A new function `parallel_fn(x1, x2=None, *args, **kwargs)` with the same
      signature as ker_fun. It computes the kernel by batching over the
      dataset in parallel over a specified number of cores. If `x2` is None,
      `x1` is used in its place. Extra positional and keyword arguments are
      replicated to every device used. The returned function carries the
      attributes `is_parallel = True` and `device_count`.

      Calling the returned function raises ValueError if |x1| is at least
      `device_count` and not divisible by it.
    """

    if device_count == -1:
        device_count = xla_bridge.device_count()

    ker_fun = pmap(ker_fun)

    def parallel_fn(x1, x2=None, *args, **kwargs):
        if x2 is None:
            # TODO(schsam): Only compute the upper triangular part of the kernel.
            x2 = x1

        n1 = x1.shape[0]

        assert x1.shape[1:] == x2.shape[1:]
        input_shape = x1.shape[1:]

        _device_count = device_count

        n1_per_device, ragged = divmod(n1, device_count)
        if n1_per_device and ragged:
            raise ValueError(
                ('Dataset size ({}) must be divisible by the number of '
                 'devices ({}).').format(n1, device_count))
        elif not n1_per_device:
            # Fewer rows than devices: use one device per row.
            _device_count = ragged
            n1_per_device = 1

        x1s = np.reshape(x1, (_device_count, n1_per_device) + input_shape)

        def broadcast(arg):
            # Replicate `arg` along a new leading axis, once per device
            # actually used for this call. This keeps every pmapped input's
            # leading axis equal to x1s's.
            # TODO(romann): remove this when JAX allows `axis_in` for `pmap`.
            return np.broadcast_to(arg, (_device_count, ) + arg.shape)

        x2s = broadcast(x2)
        args = tree_map(broadcast, args)
        kwargs = tree_map(broadcast, kwargs)

        kernel = ker_fun(x1s, x2s, *args, **kwargs)
        # NOTE(review): presumably merges the (device, row) leading axes back
        # into one; `_flatten_kernel` is defined elsewhere — confirm.
        return _flatten_kernel(kernel)

    # Set function attributes so that `serial` can detect whether or not it is
    # acting on a parallel function.
    parallel_fn.is_parallel = True
    parallel_fn.device_count = device_count

    return parallel_fn
Esempio n. 16
0
 def test_pmap_error_no_receiver(self):
     """pmap-ing a host-callback print with no outfeed receiver started must
     raise ValueError."""
     n_local = api.local_device_count()
     vargs = 2. + jnp.arange(n_local, dtype=jnp.float32)
     expected_msg = "outfeed_receiver is not started"
     with self.assertRaisesRegex(ValueError, expected_msg):
         printing = api.pmap(lambda v: hcb.id_print(v))
         printing(vargs)
Esempio n. 17
0
 def testIssue804(self):
     """psum over a pmapped axis inside lax.scan must run without error
     (presumably a regression test for JAX issue #804)."""

     def body(carry, elem):
         # New carry adds the cross-device sum; the old carry is emitted.
         return carry + lax.psum(elem, "i"), carry

     scan_from_zero = partial(lax.scan, body, 0.)
     ones = np.ones((xla_bridge.device_count(), 4))
     api.pmap(scan_from_zero, axis_name="i")(ones)  # doesn't crash
Esempio n. 18
0
 def testLogSoftmax(self):
     """log-softmax computed with psum across devices matches NumPy."""
     # pmap requires the leading axis to equal the local device count, so
     # size the input to it instead of hard-coding 9 elements.
     import jax
     n = jax.local_device_count()
     f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
     x = onp.log(onp.arange(1., n + 1., dtype=onp.float32))
     ans = pmap(f, axis_name='i')(x)
     expected = x - onp.log(onp.sum(onp.exp(x)))
     self.assertAllClose(ans, expected, check_dtypes=False)