def jit_simple_pruned_args(n, state):
  """Benchmarks dispatch of a jitted function that uses only its first arg.

  All arguments past the first are pruned away by jit, so this measures the
  cost of passing (and discarding) many arguments.

  Args:
    n: number of device values passed to the jitted function.
    state: google_benchmark state object driving the timing loop.
  """
  args = [jax.device_put(i) for i in range(n)]
  f = jax.jit(lambda *xs: xs[0] + 1)
  # Warm-up call so compilation happens outside the timed loop.
  x = f(*args)
  x.block_until_ready()
  while state:
    f(*args).block_until_ready()


# Register the many-args / pruned-args benchmarks over a range of arity.
benchmarks = []
for n in [10, 100, 1000, 2000]:
  benchmarks += [
      google_benchmark.register(
          partial(jit_simple_many_args_dispatch, n),
          name=f"jit_simple_many_args_dispatch_{n}"),
      google_benchmark.register(
          partial(jit_simple_many_args, n),
          name=f"jit_simple_many_args_{n}"),
      google_benchmark.register(
          partial(jit_simple_pruned_args_dispatch, n),
          name=f"jit_simple_pruned_args_dispatch_{n}"),
      google_benchmark.register(
          partial(jit_simple_pruned_args, n),
          name=f"jit_simple_pruned_args_{n}")
  ]


@google_benchmark.register
def jit_dispatch_without_transfer(state):
  """Benchmarks jit dispatch when the input is already resident on device."""
  # We pick up a realistic input. 224 is usual for classification and 128 a
  # TPU-friendly batch-size.
  imgs = np.ones((128, 224, 224), np.float32)
  imgs = jax.device_put(imgs)

  # BUG FIX: the original called fun(a, b) with undefined names a/b (NameError)
  # and left `imgs` unused; the device-resident images are the intended
  # operands. `fun` is jitted to match the "ensure jit has finished" comment.
  @jax.jit
  def fun(a, b):
    return a @ b
  fun(imgs, imgs).block_until_ready()  # ensure jit has finished
  while state:
    fun(imgs, imgs).block_until_ready()


# Sparse matmul benchmarks over citation-network datasets.
datasets = ("pubmed", "citeseer", "cora")

# preload datasets to avoid spam later
for data_name in datasets:
  load_data(data_name)

for data_name in datasets:
  for dtype, dtype_str in ((jnp.float32, "f32"), (jnp.float64, "f64")):
    for backend in ("cpu", "gpu"):
      for fmt in "csr", "coo":
        benchmark.register(
            partial(
                matmul_benchmark,
                fmt=fmt,
                dtype=dtype,
                backend=backend,
                data_name=data_name,
            ),
            name="-".join((data_name, dtype_str, backend, fmt)),
        )

if __name__ == "__main__":
  benchmark.main()
def benchmark_model(model: Model, data: SplitData, seed: int = 0):
  """Registers and runs google-benchmark timings for a model's train/test steps.

  Registers a training benchmark, a validation benchmark, and (when test data
  is present) a test benchmark, then invokes google-benchmark's main loop.

  Args:
    model: model exposing compiled train/test steps and init/metrics state.
    data: SplitData with train, validation, and optional test datasets.
    seed: PRNG seed for parameter initialisation and the per-step rng stream.
  """
  import google_benchmark as benchmark  # pylint: disable=import-outside-toplevel

  train_data = as_dataset(data.train_data).repeat()
  validation_data = as_dataset(data.validation_data).repeat()

  # Compile against a zeros-like dummy example so tracing happens up front.
  dummy_example = jax.tree_map(zeros_like, train_data.element_spec)
  model.compile(*dummy_example)

  rng = hk.PRNGSequence(seed)
  params, net_state, opt_state = model.init(next(rng), dummy_example[0])
  train_step = model.compiled_train_step
  test_step = model.compiled_test_step
  metrics_state = model.init_metrics_state

  # The bare list-comprehensions below exist only to block on device results.
  # pylint: disable=expression-not-assigned
  def train_benchmark(state):
    train_iter = iter(train_data)
    example = next(train_iter)
    # Warm-up step outside the timed loop so compilation is not measured.
    params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
        params, net_state, next(rng), opt_state, metrics_state, *example)
    [x.block_until_ready() for x in jax.tree_flatten(params_)[0]]
    while state:
      params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
          params_, net_state_, next(rng), opt_state_, metrics_state_, *example)
      example = next(train_iter)
      [x.block_until_ready() for x in jax.tree_flatten(params_)[0]]

  def test_benchmark(state, data):
    metrics_state_ = metrics_state
    data_iter = iter(data)
    example = next(data_iter)
    # Warm-up evaluation step outside the timed loop.
    metrics_state_, preds, loss, metrics = test_step(
        params, net_state, metrics_state_, *example)
    [
        x.block_until_ready()
        for x in jax.tree_flatten((metrics_state_, metrics, preds, loss))[0]
    ]
    while state:
      metrics_state_, preds, loss, metrics = test_step(
          params, net_state, metrics_state_, *example)
      example = next(data_iter)
      [
          x.block_until_ready()
          for x in jax.tree_flatten((metrics_state_, metrics, preds, loss))[0]
      ]
  # pylint: enable=expression-not-assigned

  # BUG FIX: the original registered the train and validation benchmarks twice
  # each, once under placeholder "UNTRUSTWORTHY-*" names — duplicate
  # registrations that double runtime and clutter output. Register each once.
  benchmark.register(train_benchmark, name="train_benchmark")
  benchmark.register(partial(test_benchmark, data=validation_data),
                     name="validation_benchmark")
  if data.test_data is not None:
    test_data = as_dataset(data.test_data).repeat()
    benchmark.register(partial(test_benchmark, data=test_data),
                       name="test_benchmark")
  benchmark.main(argv=sys.argv[:1])
def indices_replica_id_calc_cached(mesh_shape, mesh_axes, state):
  """Times shard-index / replica-id computation on the (cached) fast path.

  Args:
    mesh_shape: shape tuple used to build the global device mesh.
    mesh_axes: partition spec passed to the index/replica-id computation.
    state: google_benchmark state object driving the timing loop.
  """
  input_shape = (2048, 2048)
  mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
  while state:
    gda.get_shard_indices_replica_ids(input_shape, mesh, mesh_axes)


# Register one benchmark of each kind per (mesh shape, partition spec) pair.
benchmarks = []
for shape, partition in mesh_shapes_axes:
  benchmarks += [
      google_benchmark.register(
          partial(gda_construction_callback, partition),
          name=f"gda_construction_callback_(4, 2)_{partition}"),
      google_benchmark.register(
          partial(gda_construction_raw, shape, partition),
          name=f"gda_construction_raw_{shape}_{partition}"),
      google_benchmark.register(
          partial(indices_replica_id_calc_uncached, shape, partition),
          name=f"indices_replica_id_calc_uncached_{shape}_{partition}"),
      google_benchmark.register(
          partial(indices_replica_id_calc_cached, shape, partition),
          name=f"indices_replica_id_calc_cached_{shape}_{partition}"),
  ]

if __name__ == "__main__":
  google_benchmark.main()