def per_host_sum_pmap(in_tree):
    """Execute psum on in_tree's leaves over one device per host.

    Args:
      in_tree: pytree whose leaves are arrays (each leaf must expose `.shape`).

    Returns:
      A pytree with the same structure as `in_tree` in which every leaf has
      been psum-reduced across the selected devices (the first device listed
      for each process index).
    """
    # Group all devices by the process that owns them.
    host2devices = collections.defaultdict(list)
    for d in jax.devices():
        host2devices[d.process_index].append(d)
    # One representative device per process takes part in the reduction.
    devices = [host_devices[0] for host_devices in host2devices.values()]
    host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices)

    def pre_pmap(xs):
        # pmap maps over the leading axis, so give each leaf a size-1 one.
        return jax.tree_util.tree_map(
            lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)

    def post_pmap(xs):
        # Drop the leading axis that pre_pmap added.
        return jax.tree_util.tree_map(lambda x: x[0], xs)

    return post_pmap(host_psum(pre_pmap(in_tree)))
def test_with_kwargs(self, fake_pmap, fake_jit):
    """A pmapped, jitted function can be called with keyword arguments."""
    with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
        device_count = len(jax.devices())

        def double_and_add(x, y):
            return (x * 2) + y

        # Same composition as stacking the decorators: jit first, pmap outside.
        pmap_over_devices = functools.partial(jax.pmap, axis_size=device_count)
        foo = pmap_over_devices(jax.jit(double_and_add))

        # pmap over all available devices
        row = jnp.array([1, 2])
        inputs = jnp.broadcast_to(row, (device_count,) + row.shape)
        expected = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))

        actual = foo(x=inputs, y=inputs)
        asserts.assert_tree_all_close(actual, expected)
def testCollectivesWithTreesOfDifferentDtypes(self):
    """pmax/pmin/psum/pmean over a mixed-dtype pytree match the NumPy reductions."""
    n = len(jax.devices())
    # Four (n, n) leaves alternating between float32 and int32.
    x = {
        'a': onp.arange(1 * n * n, 2 * n * n, dtype=onp.float32).reshape([n, n]),
        'b': onp.arange(2 * n * n, 3 * n * n, dtype=onp.int32).reshape([n, n]),
        'c': onp.arange(4 * n * n, 5 * n * n, dtype=onp.float32).reshape([n, n]),
        'd': onp.arange(6 * n * n, 7 * n * n, dtype=onp.int32).reshape([n, n]),
    }

    def tree_f(f):
        # Lift a per-leaf function to a function over whole pytrees.
        return partial(tree_util.tree_map, f)

    def jax_f(p):
        # Apply collective `p` along the pmapped axis 'i'.
        return pmap(lambda x: p(x, 'i'), 'i')

    def onp_f(p):
        # Reference: reduce over axis 0, then broadcast back to the leaf shape.
        return tree_f(lambda x: onp.broadcast_to(p(x, 0), x.shape))

    # tree_map accepts several trees; the former tree_multimap alias has been
    # removed from JAX.
    assert_allclose = partial(
        tree_util.tree_map,
        partial(self.assertAllClose, check_dtypes=False))

    assert_allclose(jax_f(lax.pmax)(x), onp_f(onp.max)(x))
    assert_allclose(jax_f(lax.pmin)(x), onp_f(onp.min)(x))
    assert_allclose(jax_f(lax.psum)(x), onp_f(onp.sum)(x))
    assert_allclose(jax_f(lax.pmean)(x), onp_f(onp.mean)(x))
def multi_device_put(x, devices=None, reuse=True):
    """Memory efficient multi-device replication in JAX.

    Args:
      x: jax DeviceArray or numpy ndarray to be replicated.
      devices: a jax.devices() list or subset thereof of devices to replicate
        onto. Should match the list passed to any pmaps ingesting the
        replicated array.
      reuse: bool. If x is a DeviceArray whether to reuse its backing
        device_buffer in the resulting ShardedDeviceArray.

    Returns:
      A ShardedDeviceArray with dtype = x.dtype and
      shape = (n_devices,) + x.shape that's backed by replica device_buffers
      on each device.
    """
    # Convert _FilledConstants that don't have device_buffer, etc.
    if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
        x = np.array(x)
    if not devices:
        devices = jax.devices()
    n_devices = len(devices)
    x_aval = jax.xla.abstractify(x)
    broadcast_x_aval = jax.abstract_arrays.ShapedArray(
        (n_devices,) + x_aval.shape, x_aval.dtype)

    # Only the requested `devices` receive a replica. Previously the reuse path
    # walked every device in jax.devices() and the copy path used ordinals
    # 0..n_devices-1, both of which ignore a caller-supplied subset.
    own_device = x.device_buffer.device() if reuse else None
    broadcast_buffers = []
    for dv in devices:
        if reuse and dv == own_device:
            # x already lives here: share its buffer instead of copying.
            broadcast_buffers.append(x.device_buffer)
        else:
            broadcast_buffers.append(
                jax.xla.xc.Buffer.from_pyval(x, device=dv.id))
    return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
def __post_init__(self):
    """Build the per-device sampler once the parent post-init has run.

    The chains of this rank are divided evenly (integer division) among the
    devices reported by `jax.devices()`, and a `MetropolisSampler` configured
    for that per-device share is stored as `_sampler_device`.
    """
    super().__post_init__()

    device_count = len(jax.devices())
    chains_each = self.n_chains_per_rank // device_count

    sampler_options = dict(
        n_chains_per_rank=chains_each,
        rule=self.rule,
        n_sweeps=self.n_sweeps,
        reset_chains=self.reset_chains,
        machine_pow=self.machine_pow,
    )
    per_device_sampler = MetropolisSampler(self.hilbert, **sampler_options)

    # NOTE(review): object.__setattr__ bypasses the class's own __setattr__,
    # presumably because instances are frozen — confirm against the class.
    object.__setattr__(self, "_sampler_device", per_device_sampler)
def _to_dev(x, dev): if dev is not None: if 'cpu' in dev or 'gpu' in dev: dev_split = dev.split(':') dev_str = dev_split[0] if len(dev_split) > 1: idx = int(dev_split[1]) else: idx = 0 _jax.device_put(x, _jax.devices(dev_str)[idx]) else: raise Exception( 'Invalid device specified, must be in the form [ "cpu:idx" | "gpu:idx" ]' ) return x
def test_unordered_print_with_xmap(self):
    """Unordered debug_print under xmap emits one line per mapped element."""

    def print_element(x):
        debug_print("{}", x, ordered=False)

    f = maps.xmap(
        print_element,
        in_axes=['a'],
        out_axes=None,
        backend='cpu',
        axis_resources={'a': 'dev'},
    )

    cpu_mesh = maps.Mesh(np.array(jax.devices(backend='cpu')), ['dev'])
    with cpu_mesh:
        with capture_stdout() as output:
            f(jnp.arange(40))
            # Wait for outstanding effects before the captured text is read.
            jax.effects_barrier()
        expected_text = "".join(f"{i}\n" for i in range(40))
        self._assertLinesEqual(output(), expected_text)
def __pre_init__(self, *args, n_chains=None, n_chains_per_device=None, **kwargs):
    """
    Constructs a Metropolis Sampler.

    Args:
        hilbert: The hilbert space to sample
        rule: A `MetropolisRule` to generate random transitions from a given
            state as well as uniform random states.
        n_sweeps: The number of exchanges that compose a single sweep.
            If None, sweep_size is equal to the number of degrees of freedom
            being sampled (the size of the input vector s to the machine).
        reset_chains: If False the state configuration is not reset when
            reset() is called.
        n_chains: The total number of Markov Chains to be run in parallel on
            the available devices. This is rounded up to the next multiple of
            `len(jax.devices())` (and a warning is emitted when rounding
            changes the value). Defaults to 16 when neither `n_chains` nor
            `n_chains_per_device` is given.
        n_chains_per_device: The number of chains to be run in parallel on one
            device. Cannot be specified if n_chains is also specified.
        machine_pow: The power to which the machine should be exponentiated to
            generate the pdf (default = 2).
        dtype: The dtype of the states sampled (default = np.float32).

    Raises:
        ValueError: if both `n_chains` and `n_chains_per_device` are given.
    """
    if n_chains is not None and n_chains_per_device is not None:
        raise ValueError(
            "Cannot specify both n_chains and n_chains_per_device")
    elif n_chains is None and n_chains_per_device is None:
        n_chains = 16

    n_devices = len(jax.devices())

    # If chains is specified, round it up so every device gets the same count
    # (at least one chain per device).
    if n_chains is not None:
        n_chains_per_device = int(max(np.ceil(n_chains / n_devices), 1))
        if n_chains != n_chains_per_device * n_devices:
            import warnings

            warnings.warn(
                f"Using {n_chains_per_device*n_devices} chains "
                f"({n_chains_per_device} chains on each of {n_devices} devices).",
                category=UserWarning,
            )

    # The parent initialiser receives the total chain count.
    kwargs["n_chains"] = n_chains_per_device * n_devices
    return super().__pre_init__(*args, **kwargs)
def local_replica_groups(inner_group_size: int) -> List[List[int]]:
    """Constructs local nearest-neighbor rings given the JAX device assignment.

    For inner_group_size=8, each inner group is a tray with replica order:

    0/1 2/3
    7/6 5/4

    Args:
      inner_group_size: Number of replica in each group. Must evenly divide
        `jax.device_count()` and be a power of two between 2 and 2**14.

    Returns:
      A list of replica id groups: one list per outer group, each holding
      `inner_group_size` replica ids.

    Raises:
      AssertionError: if `inner_group_size` does not evenly divide the global
        device count or is not one of the accepted powers of two.
    """
    world_size = jax.device_count()
    outer_group_size, ragged = divmod(world_size, inner_group_size)
    assert not ragged, 'inner group size must evenly divide global device count'

    # the last device should have maximal x and y coordinate
    def bounds_from_last_device(device):
        # `device.coords` is unpacked into exactly three values and
        # `device.core_on_chip` is read, so this needs devices exposing both
        # attributes (presumably TPU devices — TODO confirm).
        x, y, z = device.coords
        return (x + 1) * (device.core_on_chip + 1), (y + 1) * (z + 1)

    # Extent in x over all devices, and extent in x/y of the local devices.
    global_x, _ = bounds_from_last_device(jax.devices()[-1])
    per_host_x, per_host_y = bounds_from_last_device(jax.local_devices(0)[-1])
    assert inner_group_size in [2 ** i for i in range(1, 15)], \
        'inner group size must be a power of two'
    if inner_group_size <= 4:
        # inner group is Nx1 (core, chip, 2x1)
        inner_x, inner_y = inner_group_size, 1
        inner_perm = range(inner_group_size)
    else:
        if inner_group_size <= global_x * 2:
            # inner group is Nx2 (2x2 tray, 4x2 DF pod host, row of hosts)
            inner_x, inner_y = inner_group_size // 2, 2
        else:
            # inner group covers the full x dimension and must be >2 in y
            inner_x, inner_y = global_x, inner_group_size // global_x
        p = np.arange(inner_group_size)
        per_group_hosts_x = 1 if inner_x < per_host_x else inner_x // per_host_x
        # View the ids as (host rows, hosts per row, rows per host, columns
        # per host), then swap the two middle axes so ids are laid out row by
        # row across the whole group.
        p = p.reshape(inner_y // per_host_y, per_group_hosts_x,
                      per_host_y, inner_x // per_group_hosts_x)
        p = p.transpose(0, 2, 1, 3)
        # Pair up consecutive rows and reverse the second row of each pair.
        p = p.reshape(inner_y // 2, 2, inner_x)
        p[:, 1, :] = p[:, 1, ::-1]
        inner_perm = p.reshape(-1)

    # Offset the same permutation by the start id of every outer group.
    inner_replica_groups = [[o * inner_group_size + i for i in inner_perm]
                            for o in range(outer_group_size)]
    return inner_replica_groups
def create_hybrid_device_mesh(mesh_shape: Sequence[int], dcn_mesh_shape: Sequence[int], devices: Optional[Sequence[Any]] = None, *, process_is_granule: bool = False) -> np.ndarray:
    """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.

    Devices are partitioned into "granules" (slices, or processes when
    `process_is_granule` is set). A mesh of shape `mesh_shape` is built inside
    every granule and the per-granule meshes are then tiled according to
    `dcn_mesh_shape`.

    Args:
      mesh_shape: shape of the logical mesh for the faster/inner network,
        ordered by increasing network intensity, e.g. [replica, data, mdl]
        where mdl has the most network communication requirements.
      dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
        in the same order as mesh_shape.
      devices: optionally, the devices to construct a mesh for. Defaults to
        jax.devices().
      process_is_granule: if True, processes are the units of the slower/outer
        network (grouping by `process_index`); otherwise devices are grouped
        by their `slice_index` attribute. Meant as a fallback for platforms
        (e.g., GPU) that don't set slice_index.

    Returns:
      A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its
      shape that can be fed into jax.experimental.maps.Mesh for hybrid
      parallelism.

    Raises:
      ValueError: if the number of granules found differs from the product of
        `dcn_mesh_shape`.
    """
    if devices is None:
        devices = jax.devices()
    attr = 'process_index' if process_is_granule else 'slice_index'
    assert hasattr(devices[0], attr)

    def members_of(wanted_id):
        return [dev for dev in devices if getattr(dev, attr) == wanted_id]

    # Granule ids are scanned upwards from 0; the first id with no devices
    # ends the scan, so the next id to try always equals len(granules).
    granules = []
    current = members_of(0)
    while current:
        granules.append(current)
        current = members_of(len(granules))

    if np.prod(dcn_mesh_shape) != len(granules):
        raise ValueError(
            'Number of slices must equal the product of dcn_mesh_shape')

    per_granule_meshes = []
    for granule in granules:
        per_granule_meshes.append(create_device_mesh(mesh_shape, granule))

    # TODO(jekbradbury): handle non-uniform DCN topologies
    granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
    pick_mesh = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])
    blocks = pick_mesh(granule_mesh)
    return np.block(blocks.tolist())
def _flatten_HeteroGraphIndex(gidx):
    """Flatten a graph index into `((u, v, num_src, num_dst), None)`.

    `u` and `v` are rows 0 and 1 of the index of the adjacency matrix of edge
    type 0 (requested on the first CPU device); `num_src` / `num_dst` are the
    node counts of that edge type's two endpoint node types.
    """
    cpu_device = jax.devices('cpu')[0]
    adjacency, _ = gidx.adjacency_matrix(
        etype=0,
        transpose=False,
        ctx=cpu_device,
    )

    src_type, dst_type = gidx.metagraph.find_edge(0)
    node_counts = (
        gidx.number_of_nodes(src_type),
        gidx.number_of_nodes(dst_type),
    )

    index = adjacency.index
    endpoints = (index[0, :], index[1, :])
    return (endpoints + node_counts, None)
def test_pjit_inherits_effects(self):
    """Tracing a pjit-wrapped function that binds effects is rejected."""
    backend = jax.default_backend()
    if backend != 'gpu' and backend != 'tpu':
        raise unittest.SkipTest("pjit only supports GPU and TPU backends")

    def bind_two_effects(x):
        effect_p.bind(effect='foo')
        effect_p.bind(effect='bar')
        return x

    f = pjit.pjit(
        bind_two_effects,
        in_axis_resources=pjit.PartitionSpec('x'),
        out_axis_resources=pjit.PartitionSpec('x'),
    )

    expect_failure = self.assertRaisesRegex(
        NotImplementedError, 'Effects not supported')
    with expect_failure:
        # The mesh is built inside the assertion context, as before.
        with maps.Mesh(np.array(jax.devices()), ['x']):
            operand = jnp.arange(jax.local_device_count())
            jax.make_jaxpr(f)(operand)
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
    """Round-trip an array through DLPack and check value, device and reuse."""
    make_values = jtu.rand_default(self.rng())
    host_values = make_values(shape, dtype)

    if gpu and jax.default_backend() == "cpu":
        raise unittest.SkipTest("Skipping GPU test case on CPU")

    platform_name = "gpu" if gpu else "cpu"
    device = jax.devices(platform_name)[0]
    x = jax.device_put(host_values, device)

    capsule = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
    # The source buffer is deleted exactly when ownership was transferred.
    self.assertEqual(take_ownership, x.device_buffer.is_deleted())

    y = jax.dlpack.from_dlpack(capsule)
    self.assertEqual(y.device(), device)
    self.assertAllClose(host_values.astype(x.dtype), y)

    # Importing the same capsule a second time must fail.
    with self.assertRaisesRegex(
            RuntimeError, "DLPack tensor may be consumed at most once"):
        jax.dlpack.from_dlpack(capsule)
def test_fake_pmap(self, context, patch, expected_type):
    """The type of a pmapped function's output equals `expected_type`.

    Whether the function was really pmapped is inferred from the type of its
    output: a sharded array type means it was.
    """
    with context(patch):
        device_count = len(jax.devices())

        def double(x):
            return x * 2

        foo = jax.pmap(double, axis_size=device_count)

        # pmap over all available devices
        row = jnp.array([1, 2])
        batched = jnp.broadcast_to(row, (device_count,) + row.shape)
        result_type = type(foo(batched))
        self.assertEqual(result_type, expected_type)
def add_rng_to_examples(example_iter, base_rng):
    """Pair every example with its own RNG derived from `base_rng`.

    Args:
      example_iter: Iterator over dataclass instances with an `example` field.
      base_rng: RNG to seed with; it is placed on the first CPU device.

    Yields:
      Copies of the incoming items whose `example` field is replaced by the
      tuple (orig_example, rng), where rng is `base_rng` folded with the
      item's position in the stream.
    """
    cpu_device = jax.devices("cpu")[0]
    host_rng = jax.device_put(base_rng, cpu_device)

    position = 0
    for item in example_iter:
        item_rng = jax.random.fold_in(host_rng, position)
        paired = (item.example, item_rng)
        yield dataclasses.replace(item, example=paired)
        position += 1
def _jax_gsddmm(gidx, op, X, Y, lhs_target, rhs_target):
    """Apply `OPS[op]` to features gathered per edge on a single-edge-type graph.

    Args:
      gidx: graph index; must report exactly one edge type.
      op: key into `OPS`. 'copy_rhs' means X is not used and 'copy_lhs' means
        Y is not used.
      X: left operand features, or None.
      Y: right operand features, or None.
      lhs_target: which index set gathers X: one of 'u', 'e', 'v', 'src',
        'edge', 'dst'.
      rhs_target: same, for Y.

    Returns:
      The result of `OPS[op]` on the gathered operands (see the review note
      at the end about the trailing axis).

    Raises:
      DGLError: if the graph has more than one edge type.
    """
    if gidx.number_of_etypes() != 1:
        from .base import DGLError
        raise DGLError("We only support gsddmm on graph with one edge type")

    use_lhs = op != 'copy_rhs'
    use_rhs = op != 'copy_lhs'

    # deal with scalar features: 1-D operands get a trailing unit axis.
    expand_lhs = use_lhs and X.ndim == 1
    if expand_lhs:
        X = jnp.expand_dims(X, -1)
    expand_rhs = use_rhs and Y.ndim == 1
    if expand_rhs:
        Y = jnp.expand_dims(Y, -1)

    adjacency, _ = gidx.adjacency_matrix(0, False, jax.devices('cpu')[0])
    dst_idxs, src_idxs = adjacency.index
    edge_idxs = jnp.arange(dst_idxs.shape[0])

    # Both the short and the long target names are accepted.
    idxs_mapping = dict(
        u=src_idxs, e=edge_idxs, v=dst_idxs,
        src=src_idxs, edge=edge_idxs, dst=dst_idxs,
    )

    gathered_lhs = None if X is None else jnp.take(
        X, idxs_mapping[lhs_target], axis=0)
    gathered_rhs = None if Y is None else jnp.take(
        Y, idxs_mapping[rhs_target], axis=0)

    Z = OPS[op](gathered_lhs, gathered_rhs)

    # NOTE(review): when every used operand was 1-D this appends one more
    # trailing axis to the result; confirm that this, rather than removing the
    # axis added above, is the intended output shape.
    every_used_operand_was_scalar = (
        (expand_lhs or not use_lhs) and (expand_rhs or not use_rhs))
    if every_used_operand_was_scalar:
        Z = jnp.expand_dims(Z, -1)
    return Z
def test_pmap_basic(self):
    """checkify on a pmapped function reports a NaN produced by `sin`."""
    device_count = len(jax.devices())
    if device_count < 2:
        raise unittest.SkipTest("requires at least 2 devices")

    @jax.pmap
    def f(x1, x2):
        return jnp.sin(x1) + jnp.sin(x2)

    finite_inputs = jnp.array([0., 2.])
    clean_err, _ = checkify.checkify(f)(finite_inputs, finite_inputs)
    self.assertIs(clean_err.get(), None)

    # The second element of this operand is infinite.
    with_inf = jnp.array([3., jnp.inf])
    nan_err, _ = checkify.checkify(f)(finite_inputs, with_inf)
    self.assertStartsWith(nan_err.get(), 'nan generated by primitive sin')
def main(argv):
    """Solve a non-interacting system on a two-nucleus chain and log the result.

    Args:
      argv: command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: if extra command-line arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    logging.info('JAX devices: %s', jax.devices())

    # 513 grid points with spacing 0.08.
    grids = np.arange(-256, 257) * 0.08

    nuclei_locations = np.array([-0.8, 0.8])
    nuclei_charges = np.array([1., 1.])
    external_potential = utils.get_atomic_chain_potential(
        grids=grids,
        locations=nuclei_locations,
        nuclear_charges=nuclei_charges,
        interaction_fn=utils.exponential_coulomb)

    solution = scf.solve_noninteracting_system(
        external_potential, num_electrons=FLAGS.num_electrons, grids=grids)
    density, total_eigen_energies, _ = solution

    logging.info('density: %s', density)
    logging.info('total energy: %f', total_eigen_energies)
def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree):
    """Run `fn` under `jax.pmap` on the configured devices and reduce the result.

    Positional and keyword arguments are either broadcast to a leading
    device axis (when `broadcast_args_to_devices` is set) or validated to
    already carry one, then passed to the pmapped `fn`. The pmap output is
    post-processed with `reduce_fn`.

    Raises:
      ValueError: if kwargs are combined with a non-zero `in_axes` or with
        `static_argnums`; if fewer devices are available than required; or if
        a pmappable argument's leading dimension differs from the number of
        devices.
    """
    if kwargs and (in_axes != 0 or static_argnums):
        raise ValueError("Do not use kwargs with `in_axes` or `static_argnums` "
                         "in pmapped function.")
    devices_ = list(devices or jax.devices(backend))
    n_devices_ = n_devices or len(devices_)
    devices_ = devices_[:n_devices_]
    if len(devices_) != n_devices_:
        raise ValueError("Number of available devices is less than required for "
                         f"test ({len(devices_)} < {n_devices_})")

    def bcast_fn(x):
        # Prepend a device axis of size n_devices_.
        return jnp.broadcast_to(x, (n_devices_,) + jnp.array(x).shape)

    if broadcast_args_to_devices:
        args = [
            tree_map(bcast_fn, arg) if idx not in static_argnums else arg
            for idx, arg in enumerate(args)
        ]
        kwargs = tree_map(bcast_fn, kwargs)
    else:
        # Pmappable axes size must be equal to number of devices.
        in_axes_ = in_axes if isinstance(in_axes,
                                         (tuple, list)) else [in_axes] * len(args)
        is_pmappable_arg = [
            idx not in static_argnums and in_axes_[idx] is not None
            for idx in range(len(args))
        ]
        # A distinct loop name keeps the flag list from being shadowed.
        for pmappable, arg in zip(is_pmappable_arg, args):
            if not pmappable:
                continue
            leaves = jax.tree_util.tree_leaves(arg)
            if not all(x.shape[0] == n_devices_ for x in leaves):
                shapes = tree_map(jnp.shape, arg)
                raise ValueError(
                    f"Pmappable arg axes size must be equal to number of devices, "
                    f"got: {shapes} (expected the first dim to be {n_devices_}). "
                    "Consider setting `broadcast_args_to_devices=True`.")

    res = jax.pmap(
        fn,
        axis_name=axis_name,
        devices=devices_,
        in_axes=in_axes,
        static_broadcasted_argnums=static_argnums,
        backend=backend)(*args, **kwargs)

    return reduce_fn(res)
def localfoe(signal, frame_size=16384, frame_step=5000, sps=1, fitkind=None, degree=2, method=lambda x: foe_mpowfftmax(x)[0]):
    '''
    Frame-wise (local) estimate computed by scanning `signal` with `method`.

    resolution = samplerate / N / 4 / sps (linear interp.)

    Args:
        signal: input array; its last axis size is used as the number of
            columns of the fitting grid.
        frame_size: forwarded to `xop.framescaninterp`.
        frame_step: forwarded to `xop.framescaninterp`.
        sps: forwarded to `xop.framescaninterp`.
        fitkind: None for no post-fit, or 'poly' (case-insensitive) to pass the
            per-frame estimates through `exp.polyfitval`.
        degree: degree handed to `exp.polyfitval` when `fitkind` is 'poly'.
        method: per-frame estimator handed to `xop.framescaninterp`.

    Returns:
        The per-frame estimates, polynomial-fitted when requested.

    Raises:
        ValueError: if `fitkind` is neither None nor 'poly'.
    '''
    cpus = jax.devices('cpu')
    # [BUG]: polyfit is buggy in GPU, so the computation is pinned to the CPU
    y = device_put(signal, cpus[0])
    dims = y.shape[-1]
    fo_local = xop.framescaninterp(y, method, frame_size, frame_step, sps)
    if fitkind is not None:
        if fitkind.lower() == 'poly':
            # Frame indices, repeated once per column, form the fit abscissa.
            fo_T = jnp.tile(jnp.arange(fo_local.shape[0])[:, None], (1, dims))
            fo_local = exp.polyfitval(fo_T, fo_local, degree)
        else:
            raise ValueError('invalid fitting method')
    return fo_local
def load_pickle(inname, metric=None, device=jax.devices()[0]):
    """Load a pickled SOM object and move its arrays onto `device`.

    Args:
      inname: path of the pickle file.
      metric: metric to attach to the loaded object. When None, a warning is
        printed and `jax_cdist` is used.
      device: JAX device that receives `centroids`, `locations` and
        `distance_mat`.

    Returns:
      The unpickled object with `device`, `metric` and the three arrays set.
    """
    # NOTE: pickle can execute arbitrary code while loading; only use this on
    # files from a trusted source.
    # The context manager closes the file even when unpickling fails.
    with open(inname, 'rb') as pickle_file:
        loaded_som = pickle.load(pickle_file)
    if metric is None:
        print(
            'WARNING : loading a JAX SOM object without its metric results in the cdist metric being used '
            'because jitted functions can not be pickled. If you want to use another metric, please '
            'specify it in the SOM.load_pickle function')
        metric = jax_cdist
    loaded_som.device = device
    loaded_som.metric = metric
    # We need to manually choose which arrays are to be put, otherwise every
    # numpy item ends up a jax object.
    loaded_som.centroids = jax.device_put(loaded_som.centroids, device=device)
    loaded_som.locations = jax.device_put(loaded_som.locations, device=device)
    loaded_som.distance_mat = jax.device_put(loaded_som.distance_mat, device=device)
    return loaded_som
def main(argv):
    """Entry point: configure devices, report the JAX process, then train.

    Args:
      argv: unused command-line arguments.
    """
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    # process_index / process_count replace the deprecated host_id / host_count.
    process_index = jax.process_index()
    process_count = jax.process_count()
    logging.info("JAX host: %d / %d", process_index, process_count)
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Borg task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"host_id: {process_index}, host_count: {process_count}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    # The returned training state is not needed here, so it is not bound.
    train_and_evaluate(FLAGS.config, FLAGS.workdir)
def learner(
    self,
    random_key: networks_lib.PRNGKey,
    replay: reverb.Client,
    counter: counting.Counter,
):
    """The Learning part of the agent.

    Builds the dataset iterator, the networks and the learner from the
    configured builder/factories and wraps the learner in a checkpointing
    runner.
    """
    iterator = self._builder.make_dataset_iterator(replay)

    # Use the configured spec when it is truthy, otherwise derive one from a
    # throw-away environment created with a dummy seed.
    environment_spec = self._environment_spec
    if not environment_spec:
        dummy_seed = 1
        environment_spec = specs.make_environment_spec(
            self._environment_factory(dummy_seed))

    # Creates the networks to optimize (online) and target networks.
    networks = self._network_factory(environment_spec)

    if self._prefetch_size <= 1:
        logging.info('Not prefetching the iterator.')
    else:
        # When working with single GPU we should prefetch to device for
        # efficiency. If running on TPU this isn't necessary as the computation
        # and input placement can be done automatically. For multi-gpu
        # currently the best solution is to pre-fetch to host although this may
        # change in the future.
        if self._device_prefetch:
            device = jax.devices()[0]
        else:
            device = None
        iterator = utils.prefetch(
            iterator, buffer_size=self._prefetch_size, device=device)

    counter = counting.Counter(counter, 'learner')
    learner = self._builder.make_learner(
        random_key, networks, iterator, replay, counter)

    checkpointing = self._checkpointing_config
    return savers.CheckpointingRunner(
        learner,
        key='learner',
        subdirectory='learner',
        time_delta_minutes=5,
        directory=checkpointing.directory,
        add_uid=checkpointing.add_uid,
        max_to_keep=checkpointing.max_to_keep)
def main():
    """Validate one model, or every pretrained model matching a pattern.

    When `args.model` names a known model config it is validated directly.
    Otherwise `args.model` is treated as a filter ('all' or an fnmatch
    pattern) over the pretrained model list, each match is validated, and a
    summary line per model is printed. Exits with status 1 when the pattern
    matches nothing.
    """
    args = parser.parse_args()
    # process_index / process_count replace the deprecated host_id / host_count.
    print('JAX host: %d / %d' % (jax.process_index(), jax.process_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)

    def _try_validate(args):
        """Run `validate`, halving the batch size after each RuntimeError."""
        batch_size = args.batch_size
        # Return as soon as validation succeeds; looping on `res is None`
        # would never terminate if validate() itself returned None.
        while True:
            print(f'Setting validation batch size to {batch_size}')
            args.batch_size = batch_size
            try:
                return validate(args)
            except RuntimeError:
                if batch_size <= 1:
                    print("Validation failed with no ability to reduce batch size. Exiting.")
                    raise
                batch_size = max(batch_size // 2, 1)
                print("Validation failed, reducing batch size by 50%")

    if get_model_cfg(args.model) is not None:
        _try_validate(args)
        return

    models = list_models(pretrained=True)
    if args.model != 'all':
        models = fnmatch.filter(models, args.model)
    if not models:
        print(f'ERROR: No models found to validate with pattern {args.model}.')
        # SystemExit works even when the `site` module's exit() helper is absent.
        raise SystemExit(1)

    print('Validating:', ', '.join(models))
    results = []
    start_batch_size = args.batch_size
    for m in models:
        args.batch_size = start_batch_size  # reset in case reduced for retry
        args.model = m
        res = _try_validate(args)
        res.update(dict(model=m))
        results.append(res)

    print('Results:')
    for r in results:
        print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
def train(model, g, feats, y_true, train_idx, optimizer):
    """Take one gradient step on the training nodes.

    Args:
      model: object whose `apply(params, g, feats)` yields per-node outputs.
      g: graph; moved to the first JAX device before use.
      feats: node features passed to the model.
      y_true: labels; the training rows are flattened and one-hot encoded
        with 40 classes.
      train_idx: index selecting the training rows.
      optimizer: object exposing `target` and `apply_gradient`.

    Returns:
      A tuple `(optimizer, loss)` with the updated optimizer and the loss
      evaluated before the update.
    """
    first_device = jax.devices()[0]
    g = g.to(first_device)

    @jax.jit
    def loss_fn(param, y_true=y_true):
        predictions = model.apply(param, g, feats)[train_idx]
        targets = jax.nn.one_hot(y_true[train_idx].flatten(), 40)
        return jnp.mean(-predictions * targets)

    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = value_and_grad_fn(optimizer.target)
    updated_optimizer = optimizer.apply_gradient(grad)
    return updated_optimizer, loss
def _pjit(inp):
    """Run `inp` through an identity pjit and return the first local shard's data.

    Fully replicated GlobalDeviceArrays skip the pjit altogether.
    """
    if not isinstance(inp, GlobalDeviceArray):
        # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
        # Shape of local_mesh will always be (1, local_device_count())
        device_grid = np.array(jax.devices()).reshape(
            jax.process_count(), jax.local_device_count())
        global_mesh = maps.Mesh(device_grid, ('processes', 'local_devices'))
        in_axis_resources = P('processes')
        if inp.ndim == 0 or not tiled:
            inp = np.expand_dims(inp, axis=0)
    elif inp.is_fully_replicated:
        return inp.local_data(0).to_py()
    else:
        global_mesh = inp.mesh
        in_axis_resources = FROM_GDA

    identity = pjit(
        lambda x: x,
        in_axis_resources=in_axis_resources,
        out_axis_resources=None)
    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
        out = identity(inp)
    return out.local_data(0).to_py()
def init_fn(master_rng, data, init_fn, optimizer):
    """Create the initial training state.

    Args:
      master_rng: PRNG key; split into the key kept in the state and the key
        used for parameter initialisation.
      data: batch passed to the pmapped `init_fn` next to the broadcast key.
      init_fn: callable `(rng, data) -> params`; it is run under `jax.pmap`.
      optimizer: object whose `init(params)` builds the optimizer state.

    Returns:
      A dict with keys `step`, `rng`, `opt_state`, `grad_acc`, `grad_count`
      and `params`. `opt_state` is initialised from a CPU copy of the first
      replica's parameters; `grad_acc` is a zero-filled tree shaped like
      `params`.
    """
    out_rng, init_rng = jax.random.split(master_rng)

    # copy the same initial params to each accelerator
    init_rng = jnp.broadcast_to(
        init_rng, (jax.local_device_count(),) + init_rng.shape)
    params = jax.pmap(init_fn)(init_rng, data)

    cpu_device = jax.devices("cpu")[0]

    # place optimizer state on CPU
    # (jax.tree_util.tree_map replaces the deprecated jax.tree_map alias)
    cpu_params = jax.tree_util.tree_map(
        lambda x: jax.device_put(x[0], device=cpu_device), params)
    opt_state = optimizer.init(cpu_params)

    return dict(
        step=np.array(0),
        rng=out_rng,
        opt_state=opt_state,
        grad_acc=jax.tree_util.tree_map(jnp.zeros_like, params),
        grad_count=np.array(0),
        params=params)
def generate_graph(idtype, grad=False):
    '''Build the 10-node, 17-edge fixture graph with random node/edge data.

    Edges (source, destination, edge id):

    s, d, eid
    0, 1, 0
    1, 9, 1
    0, 2, 2
    2, 9, 3
    0, 3, 4
    3, 9, 5
    0, 4, 6
    4, 9, 7
    0, 5, 8
    5, 9, 9
    0, 6, 10
    6, 9, 11
    0, 7, 12
    7, 9, 13
    0, 8, 14
    8, 9, 15
    9, 0, 16

    Node data is stored under 'h' and edge data under 'w'; both have `D`
    columns. With `grad=True` gradients are attached to both tensors first.
    '''
    sources = [0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 9]
    destinations = [1, 9, 2, 9, 3, 9, 4, 9, 5, 9, 6, 9, 7, 9, 8, 9, 0]
    u = F.tensor(sources)
    v = F.tensor(destinations)
    g = dgl.graph((u, v), idtype=idtype)
    assert g.device == F.ctx()

    # Node features are drawn before edge features.
    node_feats = F.randn((10, D))
    edge_feats = F.randn((17, D))
    if grad:
        node_feats = F.attach_grad(node_feats)
        edge_feats = F.attach_grad(edge_feats)

    g.ndata['h'] = node_feats
    g.edata['w'] = edge_feats
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)

    if dgl.backend.backend_name == "jax":
        import jax
        first_device = jax.devices()[0]
        g = g.to(first_device)
    return g
def test_with_partial(self, fake_pmap, fake_jit):
    """pmap/jit work on a function whose flag was bound with functools.partial."""
    with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
        device_count = len(jax.devices())

        # Testing a common use-case where non-parallel arguments are partially
        # applied before pmapping
        def combine(x, y, flag):
            if flag:
                return (x * 2) + y
            return x + y

        with_flag = functools.partial(combine, flag=True)
        foo = jax.jit(jax.pmap(with_flag, axis_size=device_count))

        # pmap over all available devices
        row = jnp.array([1, 2])
        inputs = jnp.broadcast_to(row, (device_count,) + row.shape)
        expected = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))

        # Positional and keyword invocations must agree.
        positional_result = foo(inputs, inputs)
        asserts.assert_tree_all_close(positional_result, expected)
        keyword_result = foo(x=inputs, y=inputs)
        asserts.assert_tree_all_close(keyword_result, expected)
def test_pmap_and_jit(self, context, fake_pmap, fake_jit, expected_type, expected_execution_count):
    """Check the output type and how often the Python body of the function runs.

    The Python body must have run exactly once after the first call, and
    `expected_execution_count` times after the second.
    """
    traced_calls = 0

    with context(fake_pmap, fake_jit):
        device_count = len(jax.devices())

        def counting_double(x):
            nonlocal traced_calls
            traced_calls += 1
            return x * 2

        # Same composition as stacking the decorators: jit first, pmap outside.
        pmap_over_devices = functools.partial(jax.pmap, axis_size=device_count)
        foo = pmap_over_devices(jax.jit(counting_double))

        # pmap over all available devices
        row = jnp.array([1, 2])
        inputs = jnp.broadcast_to(row, (device_count,) + row.shape)

        first_output = foo(inputs)
        self.assertEqual(type(first_output), expected_type)
        self.assertEqual(traced_calls, 1)

        foo(inputs)
        self.assertEqual(traced_calls, expected_execution_count)