def test_pjit_inherits_effects(self):
  """Tracing a pjit-wrapped function that binds effects must be rejected.

  Skipped unless the default backend is GPU or TPU. The wrapped function
  binds `effect_p` twice (effects 'foo' and 'bar') and returns its input.
  The test expects `jax.make_jaxpr` to raise `NotImplementedError` with a
  message matching 'Effects not supported' when it runs inside a mesh that
  spans `jax.devices()` along the single axis 'x'.
  """
  if jax.default_backend() not in {'gpu', 'tpu'}:
    raise unittest.SkipTest("pjit only supports GPU and TPU backends")

  def effectful_identity(x):
    # Bind two distinct effects, then pass the input through unchanged.
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x

  # The same spec shards both the input and the output along mesh axis 'x'.
  x_spec = pjit.PartitionSpec('x')
  sharded_fn = pjit.pjit(
      effectful_identity,
      in_axis_resources=x_spec,
      out_axis_resources=x_spec,
  )
  # The assertion context is entered first, then the mesh, so errors raised
  # while building the mesh or tracing are both checked.
  with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'), \
      maps.Mesh(np.array(jax.devices()), ['x']):
    jax.make_jaxpr(sharded_fn)(jnp.arange(jax.local_device_count()))
def test_pjit_inherits_effects(self):
  """Tracing a pjit-wrapped function that binds effects must be rejected.

  The wrapped function binds `effect_p` twice (effects 'foo' and 'bar') and
  returns its input. The test expects `jax.make_jaxpr` to raise
  `NotImplementedError` with a message matching 'Effects not supported'
  when it runs inside a mesh that spans `jax.devices()` along the single
  axis 'x'. No backend skip is applied here.
  """

  def effectful_identity(x):
    # Bind two distinct effects, then pass the input through unchanged.
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x

  # The same spec shards both the input and the output along mesh axis 'x'.
  x_spec = pjit.PartitionSpec('x')
  sharded_fn = pjit.pjit(
      effectful_identity,
      in_axis_resources=x_spec,
      out_axis_resources=x_spec,
  )
  # The assertion context is entered first, then the mesh, as in a nested
  # pair of `with` statements.
  with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'), \
      maps.Mesh(np.array(jax.devices()), ['x']):
    jax.make_jaxpr(sharded_fn)(jnp.arange(jax.local_device_count()))
def _matrix_inverse_pth_root_pjit(xs, ps):
  """Shards inputs across the mesh, inverts each shard, then un-shards.

  This is a closure: it reads `mesh_axis_names` and
  `_matrix_inverse_pth_root_vmap` from the enclosing scope.

  Args:
    xs: the concatenated statistics matrices to invert.
    ps: per-matrix values passed alongside `xs` to
      `_matrix_inverse_pth_root_vmap` (presumably the root exponents p --
      TODO confirm against the caller).

  Returns:
    A `(preconditioners, errors)` tuple. Both come from
    `_matrix_inverse_pth_root_vmap` and are re-laid-out with
    `out_axis_resources=None`, i.e. without partitioning.
  """
  # One spec that splits the leading dimension over every mesh axis at once.
  # NOTE(review): assumes the leading dim of xs/ps is divisible by the total
  # mesh size -- TODO confirm.
  shard_spec = pjit.PartitionSpec(tuple(mesh_axis_names))

  # Scatter: take unpartitioned inputs and lay them out sharded.
  scatter = pjit.pjit(
      lambda a, b: (a, b),
      in_axis_resources=None,
      out_axis_resources=shard_spec,
  )
  sharded_xs, sharded_ps = scatter(xs, ps)

  # Each core computes the inverse pth root of its own shard.
  sharded_preconds, sharded_errs = _matrix_inverse_pth_root_vmap(
      sharded_xs, sharded_ps)

  # Gather: take the sharded results and lay them out unpartitioned again.
  gather = pjit.pjit(
      lambda a, b: (a, b),
      in_axis_resources=(shard_spec, shard_spec),
      out_axis_resources=(None, None),
  )
  preconditioners, errors = gather(sharded_preconds, sharded_errs)
  return preconditioners, errors
def test_extract_prefixed_keys_from_state_specs(self):
  """Checks the prefixed keys extracted from a TrainState of partition specs.

  The TrainState holds an empty `step` spec, the specs for a single 4x8
  weight 'w', and an empty `opt_states` dict. Flattening the extracted
  prefixed keys must give exactly ['step', 'mdl_vars/w']; the empty
  `opt_states` contributes no keys.
  """
  # Partition specs for one weight on a 1x1 device mesh with axes ('a', 'b').
  var_specs = base_layer.var_partition_specs(
      {'w': py_utils.weight_params(shape=(4, 8))},
      device_mesh=np.arange(1).reshape([1, 1]),
      device_axis_names=['a', 'b'])
  state_specs = train_states.TrainState(
      step=pjit.PartitionSpec(), mdl_vars=var_specs, opt_states={})

  prefixed_keys = py_utils.extract_prefixed_keys_from_nested_map(state_specs)
  # Only the leaves are needed; the treedef is discarded.
  leaf_keys = jax.tree_util.tree_flatten(prefixed_keys)[0]
  self.assertListEqual(['step', 'mdl_vars/w'], leaf_keys)