def testVmapOfPmapTuple(self):
  device_count = xla_bridge.device_count()
  f0 = lambda *x: x
  f1 = pmap(f0, axis_name='i')

  ax = onp.random.randn(device_count, 2, 50, 60)
  ay = onp.random.randn(device_count, 30, 2)
  az1 = onp.random.randn(device_count, 20)
  az2 = onp.random.randn(2, device_count, 20)

  bx, by, bz = vmap(f1, in_axes=(1, 2, (None, 0)), out_axes=(1, 2, 0))(
      ax, ay, (az1, az2))

  self.assertAllClose(ax, bx, check_dtypes=False)
  self.assertAllClose(ay, by, check_dtypes=False)

  bz1, bz2 = bz
  expected_bz1 = onp.broadcast_to(az1, (2,) + az1.shape)
  self.assertAllClose(expected_bz1, bz1, check_dtypes=False)
  # The identity pmap should return az2 unchanged.
  self.assertAllClose(az2, bz2, check_dtypes=False)
def testBadAxisSizeError(self):
  if xla_bridge.device_count() == 1:
    raise SkipTest("this test requires multiple devices")

  f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
           devices=xla_bridge.devices())
  with self.assertRaisesRegex(
      ValueError, r"Leading axis size of input to pmapped function must "
      r"equal the number of local devices passed to pmap. Got axis_size=1, "
      r"num_local_devices=\d."):
    f(np.ones(1))

  with self.assertRaisesRegex(
      ValueError, r"Leading axis size of input to pmapped function must "
      r"equal the number of local devices passed to pmap. Got axis_size=\d, "
      r"num_local_devices=\d."):
    f(np.ones(xla_bridge.device_count() + 1))
def testCollectivesWithTreesOfDifferentDtypes(self):
  n = len(jax.devices())
  x = {
      'a': onp.arange(1 * n * n, 2 * n * n, dtype=onp.float32).reshape([n, n]),
      'b': onp.arange(2 * n * n, 3 * n * n, dtype=onp.int32).reshape([n, n]),
      'c': onp.arange(4 * n * n, 5 * n * n, dtype=onp.float32).reshape([n, n]),
      'd': onp.arange(6 * n * n, 7 * n * n, dtype=onp.int32).reshape([n, n]),
  }
  tree_f = lambda f: partial(tree_util.tree_map, f)
  jax_f = lambda p: pmap(lambda x: p(x, 'i'), 'i')
  onp_f = lambda p: tree_f(lambda x: onp.broadcast_to(p(x, 0), x.shape))
  assert_allclose = partial(tree_util.tree_multimap,
                            partial(self.assertAllClose, check_dtypes=False))
  assert_allclose(jax_f(lax.pmax)(x), onp_f(onp.max)(x))
  assert_allclose(jax_f(lax.pmin)(x), onp_f(onp.min)(x))
  assert_allclose(jax_f(lax.psum)(x), onp_f(onp.sum)(x))
  assert_allclose(jax_f(lax.pmean)(x), onp_f(onp.mean)(x))
def f_pmapped(x, *args, **kwargs):
  args_np, args_np_idxs = [], []
  args_other = {}

  # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
  # https://github.com/google/jax/issues/912
  # Filter out `np.ndarray`s from other arguments.
  for i, arg in enumerate(args):
    if _is_np_ndarray(arg):
      args_np.append(arg)
      args_np_idxs.append(i)
    else:
      args_other[i] = arg

  # Check cache before jitting.
  _key = key + tuple(args_other.items()) + tuple(kwargs.items())
  if _key in cache:
    _f = cache[_key]
  else:
    # Define a `np.ndarray`-only function as a closure over other arguments.
    def _f(_x, *_args_np):
      # Merge args.
      _args_np = {i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)}
      _args = _merge_dicts(_args_np, args_other)
      _args = tuple(v for k, v in sorted(_args.items()))
      return f(_x, *_args, **kwargs)

    _f = jit(_f) if device_count == 0 else pmap(_f)
    cache[_key] = _f

  # Broadcast `np.ndarray` arguments and apply the new function to them.
  args_np = tree_map(broadcast, args_np)
  return _f(x, *args_np)
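# A minimal, self-contained sketch of the caching pattern used in `f_pmapped`
# above: compiled functions are cached keyed on the non-ndarray ("static")
# arguments, so a new static value triggers a fresh jit/pmap trace while array
# arguments reuse the cached executable. All names below are hypothetical.
import jax
import jax.numpy as jnp

_compiled_cache = {}

def apply_with_static(fn, x, power):
  # `power` plays the role of a non-ndarray argument closed over by the traced
  # function; `x` is the array argument passed through normally.
  cache_key = (fn, power)
  if cache_key not in _compiled_cache:
    _compiled_cache[cache_key] = jax.jit(lambda x_: fn(x_) ** power)
  return _compiled_cache[cache_key](x)

# apply_with_static(jnp.sin, jnp.arange(3.), 2) traces and compiles once per
# distinct `power`; later calls with the same `power` reuse the cached function.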
def testNestedPmapConstant(self):
  if xla_bridge.device_count() == 1:
    raise SkipTest("this test requires multiple devices")

  f = pmap(pmap(lambda x: 3))
  shape = (2, xla_bridge.device_count() // 2, 3)
  x = np.arange(prod(shape)).reshape(shape)
  ans = f(x)
  expected = 3 * onp.ones(shape[:2])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Test that 'ans' was properly replicated across devices.
  expected_sharded = pmap(pmap(lambda x: x))(expected)
  self.assertEqual([b.device() for b in ans.device_buffers],
                   [b.device() for b in expected_sharded.device_buffers])

  f = pmap(pmap(lambda x: (x, 3)))
  x_sharded, ans = f(x)
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertEqual([b.device() for b in ans.device_buffers],
                   [b.device() for b in x_sharded.device_buffers])
def f(x):
  return api.pmap(np.mean)(x)
def g(z):
  return pmap(lambda x: x * y)(z)
def testShardedDeviceArrayBlockUntilReady(self):
  x = onp.arange(xla_bridge.device_count())
  x = pmap(lambda x: x)(x)
  x.block_until_ready()  # doesn't crash
def distributed_matrix_vector(x, y):
  """Matrix-vector multiply. First batch over devices, then go row by row."""
  fv = lambda z: lax.map(lambda j: vv(j, y), z)
  res = pmap(fv)(x.reshape((jax.device_count(), -1) + tuple(x.shape[1:])))
  res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
  return res
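# Hypothetical usage of `distributed_matrix_vector` above. It relies on a free
# function `vv` that is not shown here; a plain vector-vector dot product is
# assumed, and the leading dimension of `x` is assumed to be divisible by the
# local device count.
import jax
import jax.numpy as jnp

vv = lambda u, w: jnp.vdot(u, w)  # assumed row kernel

n_dev = jax.device_count()
x = jnp.ones((n_dev * 4, 8))
y = jnp.ones((8,))
# distributed_matrix_vector(x, y) splits the rows of `x` across devices and
# should agree with jnp.dot(x, y), returning an array of shape (n_dev * 4,).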
def pmvm(a, b):
  a = a.reshape((nrep, -1, a.shape[1]))
  func = pmap(lambda z: np.dot(z, b))
  return func(a).reshape(b.shape)
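# Hypothetical usage of `pmvm` above. `nrep` is a free variable assumed to be
# the number of replicas (e.g. the local device count), and `a` is assumed to
# be square so that `a @ b` has the same shape as `b`.
import jax
import jax.numpy as jnp

nrep = jax.device_count()
a = jnp.eye(nrep * 2)
b = jnp.ones((nrep * 2,))
# pmvm(a, b) splits the rows of `a` across devices and should agree with
# jnp.dot(a, b).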
def testMismatchedAxisSizes(self):
  n = xla_bridge.device_count()
  f = pmap(lambda x, y: x + y)
  jtu.check_raises_regexp(
      lambda: f(onp.random.randn(n), onp.random.randn(n - 1)), ValueError,
      "Axis size .* does not match leading dimension of shape .*")
def testReduceSum(self):
  f = lambda x: lax.psum(x, 'i')
  ans = pmap(f, axis_name='i')(onp.ones(4))
  expected = 4 * onp.ones(4)
  self.assertAllClose(ans, expected, check_dtypes=False)
def testConstantFunction(self):
  f = lambda x: 3
  ans = pmap(f, axis_name='i')(onp.ones(4))
  expected = 3 * onp.ones(4)
  self.assertAllClose(ans, expected, check_dtypes=False)
def main(unused_argv):
  from jax.api import grad, jit, vmap, pmap, device_put

  # The following is required to use TPU Driver as JAX's backend.
  if FLAGS.TPU:
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = "grpc://" + os.environ['TPU_ADDR'] + ':8470'
    TPU_ADDR = os.environ['TPU_ADDR']
  ndevices = xla_bridge.device_count()
  if not FLAGS.TPU:
    ndevices = 1
  pmap = partial(pmap, axis_name='i')

  # Set up some experiment parameters.
  meas_step = FLAGS.meas_step
  training_epochs = int(FLAGS.epochs)

  tmult = 1.0
  if FLAGS.physical:
    tmult = FLAGS.lr
  if FLAGS.physicalL2:
    tmult = FLAGS.L2 * tmult
  if FLAGS.physical:
    training_epochs = 1 + int(FLAGS.epochs / tmult)
  print('Evolving for {:}e'.format(training_epochs))

  losst = FLAGS.losst
  learning_rate = FLAGS.lr
  batch_size_per_device = FLAGS.bs
  N = FLAGS.N
  K = FLAGS.K

  batch_size = batch_size_per_device * ndevices
  steps_per_epoch = 50000 // batch_size
  training_steps = training_epochs * steps_per_epoch

  # Build the log filename from FLAGS.
  filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
  if FLAGS.momentum:
    filename += '_mom'
  if FLAGS.L2_sch:
    filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(FLAGS.delay)
  if FLAGS.seed != 1:
    filename += 'seed' + str(FLAGS.seed)
  filename += '_L2' + str(FLAGS.L2)
  if FLAGS.std_wrn_sch:
    filename += '_stddec'
    if FLAGS.physical:
      filename += 'phys'
  else:
    filename += '_ctlr'
  if not FLAGS.augment:
    filename += '_noaug'
  if not FLAGS.mix:
    filename += '_nomixup'
  filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)

  if FLAGS.jobdir is not None:
    filedir = os.path.join('wrnlogs', FLAGS.jobdir)
  else:
    filedir = 'wrnlogs'
  if not os.path.exists(filedir):
    os.makedirs(filedir)
  filedir = os.path.join(filedir, filename + '.csv')

  print('Saving log to ', filename)
  print('Found {} cores.'.format(ndevices))

  # Load CIFAR10 data and create a minimal pipeline.
  train_images, train_labels, test_images, test_labels = utils.load_data('cifar10')
  train_images = np.reshape(train_images, (-1, 32, 32 * 3))
  train = (train_images, train_labels)
  test = (test_images, test_labels)
  k = train_labels.shape[-1]
  train = utils.shard_data(train, ndevices)
  test = utils.shard_data(test, ndevices)

  # Create a Wide ResNet and replicate its parameters across the devices.
  initparams, f, _ = utils.WideResnetnt(N, K, k)

  # Loss and optimizer definitions.
  l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
  l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))

  currL2 = FLAGS.L2
  L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices,)))

  def xentr(params, images_and_labels):
    images, labels = images_and_labels
    return -np.mean(stax.logsoftmax(f(params, images)) * labels)

  def mse(params, data_tuple):
    """MSE loss."""
    x, y = data_tuple
    return 0.5 * np.mean((y - f(params, x))**2)

  if losst == 'xentr':
    print('Using xentr')
    lossm = xentr
  else:
    print('Using mse')
    lossm = mse

  loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

  def accuracy(params, images_and_labels):
    images, labels = images_and_labels
    return np.mean(
        np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels, axis=1),
                 dtype=np.float32))

  # Define the optimizer.
  if FLAGS.std_wrn_sch:
    lr = learning_rate
    first_epoch = int(60 / 200 * training_epochs)
    learning_rate_fn = optimizers.piecewise_constant(
        np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
        np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
  else:
    learning_rate_fn = optimizers.make_schedule(learning_rate)

  if FLAGS.momentum:
    momentum = 0.9
  else:
    momentum = 0

  @pmap
  def update_step(step, state, batch_state, L2):
    batch, batch_state = batch_fn(batch_state)
    params = get_params(state)
    dparams = grad_loss(params, batch, L2)
    dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
    return step + 1, apply_fn(step, dparams, state), batch_state

  @pmap
  def evaluate(state, data, L2):
    params = get_params(state)
    lossmm = lossm(params, data)
    l2mm = l2_reg(params)
    return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

  # Initialization and loading.
  _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
  replicate_array = lambda x: np.broadcast_to(x, (ndevices,) + x.shape)
  replicated_params = tree_map(replicate_array, params)

  grad_loss = jit(grad(loss))
  init_fn, apply_fn, get_params = optimizers.momentum(learning_rate_fn, momentum)
  apply_fn = jit(apply_fn)

  key = random.PRNGKey(FLAGS.seed)
  batchinit_fn, batch_fn = utils.sharded_minibatcher(
      batch_size, ndevices, transform=FLAGS.augment, k=k, mix=FLAGS.mix)
  batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
  state = pmap(init_fn)(replicated_params)

  if FLAGS.checkpointing:
    # Load a checkpoint if available/provided.
    single_state = init_fn(params)
    i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
        filename, single_state, params, full_file=FLAGS.load_w, ndevices=ndevices)
    if i0 is not None:
      filename = filename0
      if batch_stateb is not None:
        batch_state = batch_stateb
      if load_params is not None:
        state = pmap(init_fn)(load_params)
      else:
        state = load_state
    else:
      i0 = 0
  else:
    i0 = 0

  if FLAGS.steps_from_load:
    training_steps = i0 + training_steps

  batch_xs, _ = pmap(batch_fn)(batch_state)

  train_loss = []
  train_accuracy = []
  lrL = []
  test_loss = []
  test_accuracy = []
  test_L2, test_lm, train_lm, train_L2 = [], [], [], []
  L2_t = []
  idel0 = i0

  start = time.time()
  step = pmap(lambda x: x)(i0 * np.ones((ndevices,)))

  # Start the training loop.
  if FLAGS.checkpointing:
    print('Evolving for {:}e and saving every {:}s'.format(training_epochs,
                                                           FLAGS.checkpointing))
  print('Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\t'
        'Train Error\tTest Error\tTime / Epoch')

  for i in range(i0, training_steps):
    if i % meas_step == 0:
      # Make a measurement.
      l, a, lm, L2m = evaluate(state, test, L2p)
      test_loss += [np.mean(l)]
      test_accuracy += [np.mean(a)]
      test_lm += [np.mean(lm)]
      test_L2 += [np.mean(L2m)]

      train_batch, _ = pmap(batch_fn)(batch_state)
      l, a, lm, L2m = evaluate(state, train_batch, L2p)
      train_loss += [np.mean(l)]
      train_accuracy += [np.mean(a)]
      train_lm += [np.mean(lm)]
      train_L2 += [np.mean(L2m)]

      L2_t.append(currL2)
      lrL += [learning_rate_fn(i)]

      if (FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(train_lm) > 2
          and ((minloss <= train_lm[-1] and minloss <= train_lm[-2]) or
               (maxacc >= train_accuracy[-1] and maxacc >= train_accuracy[-2]))):
        # If AutoL2 is on and we are beyond the refractory period, decay L2 if
        # the loss or the error has increased over the last two measurements.
        print('Decaying L2 to', currL2 / FLAGS.L2dec)
        currL2 = currL2 / FLAGS.L2dec
        L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices,)))
        idel0 = i
      elif FLAGS.L2_sch and len(train_lm) >= 2:
        # Update the running minimum loss and maximum accuracy.
        try:
          maxacc = max(train_accuracy[-2], maxacc)
          minloss = min(train_lm[-2], minloss)
        except NameError:
          maxacc, minloss = train_accuracy[-2], train_lm[-2]

      if i % (meas_step * 10) == 0 or i == i0:
        # Save measurements to csv.
        epoch = batch_size * i / 50000
        dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
        print(('{}\t' + ('{: .4f}\t' * 7)).format(
            epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
            test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))
        start = time.time()

        data = {
            'train_loss': train_loss,
            'test_loss': test_loss,
            'train_acc': train_accuracy,
            'test_acc': test_accuracy
        }
        data['train_bareloss'] = train_lm
        data['train_L2'] = train_L2
        data['test_bareloss'] = test_lm
        data['test_L2'] = test_L2
        data['L2_t'] = L2_t

        df = pd.DataFrame(data)
        df['learning_rate'] = lrL
        df['width'] = K
        df['batch_size'] = batch_size
        df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step
        df.to_csv(filedir, index=False)

    if FLAGS.checkpointing:
      # Save the model.
      if i % FLAGS.checkpointing == 0 and i > i0:
        if not os.path.exists('weights/'):
          os.makedirs('weights/')
        saveparams = tree_flatten(state[0])[0]
        if ndevices > 1:
          saveparams = [el[0] for el in saveparams]
        saveparams = np.concatenate([el.reshape(-1) for el in saveparams])
        step0 = i
        print('Step', i)
        print('saving at', filename, step0, 'size:', saveparams.shape)
        utils.save_weights(filename, step0, saveparams, batch_state)

    # Update.
    step, state, batch_state = update_step(step, state, batch_state, L2p)

  print('Training done')
  if FLAGS.TPU:
    with open('done/' + TPU_ADDR, 'w') as fp:
      fp.write(filedir)
def _parallel(ker_fun, device_count=-1):
  """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count must evenly divide the dataset size.

  Given two datasets x1 and x2, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  [|x1| / device_count, |x2|].

  Args:
    ker_fun: A function that computes a kernel between two datasets,
      ker_fun(x1, x2). Here x1 and x2 are `np.ndarray`s of floats of shape
      [n1,] + input_shape and [n2,] + input_shape. The kernel function
      should return a PyTree.
    device_count: Integer specifying the number of devices over which to split
      the data. If device_count is -1, the computation is parallelized over all
      available devices.

  Returns:
    A new function with the same signature as ker_fun that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.
  """
  if device_count == -1:
    device_count = xla_bridge.device_count()

  def broadcast(arg):
    # TODO(romann): remove this when JAX allows `axis_in` for `pmap`.
    return np.broadcast_to(arg, (device_count,) + arg.shape)

  ker_fun = pmap(ker_fun)

  def parallel_fn(x1, x2=None, *args, **kwargs):
    if x2 is None:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
      x2 = x1

    n1 = x1.shape[0]

    assert x1.shape[1:] == x2.shape[1:]
    input_shape = x1.shape[1:]

    _device_count = device_count

    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must be divisible by the number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      _device_count = ragged
      n1_per_device = 1

    if n1_per_device:
      x1s = np.reshape(x1, (_device_count, n1_per_device,) + input_shape)
    else:
      x1s = np.reshape(x1, (n1, 1,) + input_shape)

    x2s = broadcast(x2)
    args = tree_map(broadcast, args)
    kwargs = tree_map(broadcast, kwargs)

    kernel = ker_fun(x1s, x2s, *args, **kwargs)
    return _flatten_kernel(kernel)

  # Set function attributes so that `serial` can detect whether or not it is
  # acting on a parallel function.
  parallel_fn.is_parallel = True
  parallel_fn.device_count = device_count

  return parallel_fn
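# Minimal self-contained sketch of the row-splitting strategy implemented by
# `_parallel` above, using a toy dot-product kernel. Names here are
# hypothetical, and the flattening step of the real code (`_flatten_kernel`)
# is replaced by a plain reshape, assuming the kernel returns a single ndarray.
import jax
import jax.numpy as jnp

def _toy_kernel(x1, x2):
  return jnp.dot(x1, x2.T)  # shape [n1, n2]

def toy_parallel_kernel(x1, x2):
  n_dev = jax.device_count()
  n1 = x1.shape[0]
  assert n1 % n_dev == 0, 'dataset size must be divisible by the device count'
  x1s = x1.reshape((n_dev, n1 // n_dev) + x1.shape[1:])  # split rows over devices
  x2s = jnp.broadcast_to(x2, (n_dev,) + x2.shape)        # replicate x2 on each device
  k = jax.pmap(_toy_kernel)(x1s, x2s)                    # [n_dev, n1 / n_dev, n2]
  return k.reshape((n1,) + k.shape[2:])                  # back to [n1, n2]

# toy_parallel_kernel(x, x) should agree with _toy_kernel(x, x) whenever the
# device count divides the leading dimension of x.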
def test_pmap_error_no_receiver(self):
  # Check for errors if starting pmap without a consumer active.
  vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)
  with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
    api.pmap(lambda x: hcb.id_print(x))(vargs)
def testIssue804(self):
  num_devices = xla_bridge.device_count()
  f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i"), c), 0.)
  api.pmap(f, axis_name="i")(np.ones((num_devices, 4)))  # doesn't crash
def testLogSoftmax(self):
  f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
  x = onp.log(onp.arange(1., 10., dtype=onp.float32))
  ans = pmap(f, axis_name='i')(x)
  expected = x - onp.log(onp.sum(onp.exp(x)))
  self.assertAllClose(ans, expected, check_dtypes=False)