k = jax.random.PRNGKey(42).block_until_ready() while state: jax.eval_shape(init, k, x) @google_benchmark.register(name=f'{model.__name__}_init_fast') def init_fast_bench(state): """Benchmark runtime of compiled hk.init_fn of model.""" x = jnp.ones(input_shape) k = jax.random.PRNGKey(42).block_until_ready() while state: hk.experimental.fast_eval_shape(init, k, x) return init_slow_bench, init_fast_bench # Models to be benchmarked @init_benchmark def mlp(x): return hk.nets.MLP([300, 100, 10])(x) @init_benchmark def resnet_50(x): return hk.nets.ResNet50(num_classes=10)(x, is_training=True, test_local_stats=True) if __name__ == '__main__': google_benchmark.main()
@required_devices(2)
def pmap_simple_2_devices(state):
  """Benchmarks a pmapped add/sub pair over 2 devices."""
  # NOTE(review): unlike pmap_simple_8_devices, no @benchmark.register is
  # visible on this function here -- confirm it is applied on a preceding line.
  _time_pmap_add_sub(state, [1, 2], [3, 4])


@benchmark.register
@required_devices(8)
def pmap_simple_8_devices(state):
  """Benchmarks a pmapped add/sub pair over 8 devices."""
  _time_pmap_add_sub(
      state,
      [1, 2, 3, 4, 5, 6, 7, 8],
      [2, 3, 4, 5, 6, 7, 8, 9],
  )


def _time_pmap_add_sub(state, lhs, rhs):
  """Shared body of the pmap_simple_* benchmarks.

  Builds a pmapped function returning ``(x + y, x - y)`` and calls it once on
  ``jnp.array(lhs)`` / ``jnp.array(rhs)``. The outputs of that first call are
  reused as the inputs of every timed call. Each timed iteration waits on both
  results before continuing.
  """

  def add_sub(x, y):
    return x + y, x - y

  step = jax.pmap(add_sub)
  # Untimed first call; its outputs become the loop inputs.
  x, y = step(jnp.array(lhs), jnp.array(rhs))
  while state:
    total, diff = step(x, y)
    total.block_until_ready()
    diff.block_until_ready()


def swap(a, b):
  """Returns ``a`` and ``b`` in reversed order."""
  return (b, a)


if __name__ == "__main__":
  benchmark.main()
def benchmark_model(model: Model, data: SplitData, seed: int = 0):
  """Registers and runs google_benchmark benchmarks for ``model``.

  The model is compiled against a zero-filled example built from the training
  dataset's ``element_spec`` and initialised once. The function then registers
  benchmarks for the model's compiled train step and compiled test step. The
  test step runs on the validation split and, if ``data.test_data`` is not
  None, on the test split. Finally it calls ``google_benchmark.main``.

  The ``UNTRUSTWORTHY-*`` entries re-register the same benchmark functions
  ahead of the "real" ones. Presumably these are warm-up runs whose numbers
  should be ignored -- TODO confirm.

  Args:
    model: Model to benchmark. It is compiled with ``model.compile`` and
      initialised with ``model.init``, which must return
      ``(params, net_state, opt_state)``. Its ``compiled_train_step``,
      ``compiled_test_step`` and ``init_metrics_state`` attributes are used.
    data: Data splits. ``train_data`` and ``validation_data`` are required;
      ``test_data`` may be None, in which case no test benchmark is registered.
    seed: Seed for the ``hk.PRNGSequence`` that supplies init and step keys.
  """
  import google_benchmark as benchmark  # pylint: disable=import-outside-toplevel

  # Datasets repeat forever so the timed loops never exhaust their iterators.
  train_data = as_dataset(data.train_data).repeat()
  validation_data = as_dataset(data.validation_data).repeat()

  # Zero-filled example with the same structure as a training element; it is
  # used only to compile and initialise the model.
  dummy_example = jax.tree_util.tree_map(zeros_like, train_data.element_spec)
  model.compile(*dummy_example)

  rng = hk.PRNGSequence(seed)
  params, net_state, opt_state = model.init(next(rng), dummy_example[0])
  train_step = model.compiled_train_step
  test_step = model.compiled_test_step
  metrics_state = model.init_metrics_state

  def block_until_ready(tree):
    """Calls ``block_until_ready`` on every leaf of ``tree``."""
    for leaf in jax.tree_util.tree_leaves(tree):
      leaf.block_until_ready()

  def train_benchmark(state):
    """Times repeated compiled train steps on successive training batches."""
    train_iter = iter(train_data)
    example = next(train_iter)
    # Untimed first step from the initial state; its outputs seed the loop.
    params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
        params, net_state, next(rng), opt_state, metrics_state, *example)
    block_until_ready(params_)
    while state:
      params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
          params_, net_state_, next(rng), opt_state_, metrics_state_, *example)
      example = next(train_iter)
    # NOTE(review): this wait runs after `while state` has exited, i.e.
    # presumably outside google_benchmark's timed region, so device work still
    # in flight at loop exit may not be counted -- confirm this is intended.
    block_until_ready(params_)

  def test_benchmark(state, data):
    """Times repeated compiled test steps over ``data`` with fixed params."""
    data_iter = iter(data)
    example = next(data_iter)
    # Untimed first step from the initial metrics state.
    metrics_state_, preds, loss, metrics = test_step(
        params, net_state, metrics_state, *example)
    block_until_ready((metrics_state_, metrics, preds, loss))
    while state:
      metrics_state_, preds, loss, metrics = test_step(
          params, net_state, metrics_state_, *example)
      example = next(data_iter)
    # NOTE(review): same post-loop timing caveat as in train_benchmark.
    block_until_ready((metrics_state_, metrics, preds, loss))

  benchmark.register(train_benchmark, name="UNTRUSTWORTHY-train")
  benchmark.register(train_benchmark, name="train_benchmark")
  benchmark.register(
      partial(test_benchmark, data=validation_data),
      name="UNTRUSTWORTHY-validation")
  benchmark.register(
      partial(test_benchmark, data=validation_data),
      name="validation_benchmark")
  if data.test_data is not None:
    test_data = as_dataset(data.test_data).repeat()
    benchmark.register(
        partial(test_benchmark, data=test_data), name="test_benchmark")

  # Only the program name is forwarded, so no other command-line arguments
  # reach google_benchmark's flag parser.
  benchmark.main(argv=sys.argv[:1])