# --- Example 1 (extraction artifact: fragment boundary; original marker "Exemplo n.º 1 / 0") ---

def jit_simple_pruned_args(n, state):
    """Benchmark a jitted call that uses only the first of its n arguments.

    The jitted function touches ``xs[0]`` alone, so the remaining n - 1
    device arguments can be pruned by jit. ``state`` is the
    google_benchmark iteration state driving the timed loop.
    """
    inputs = [jax.device_put(value) for value in range(n)]
    add_one_to_first = jax.jit(lambda *xs: xs[0] + 1)

    # Warm-up call so tracing/compilation cost stays out of the timed loop.
    add_one_to_first(*inputs).block_until_ready()

    while state:
        add_one_to_first(*inputs).block_until_ready()


# Register each simple-jit benchmark variant once per argument count.
benchmarks = []
for n in (10, 100, 1000, 2000):
    variants = (
        (jit_simple_many_args_dispatch, f"jit_simple_many_args_dispatch_{n}"),
        (jit_simple_many_args, f"jit_simple_many_args_{n}"),
        (jit_simple_pruned_args_dispatch,
         f"jit_simple_pruned_args_dispatch_{n}"),
        (jit_simple_pruned_args, f"jit_simple_pruned_args_{n}"),
    )
    for fn, bench_name in variants:
        benchmarks.append(
            google_benchmark.register(partial(fn, n), name=bench_name))


@google_benchmark.register
def jit_dispatch_without_transfer(state):
    # NOTE(review): this function is truncated — its body continues past an
    # extraction-artifact boundary in this file and is not fully visible here.
    # We pick up a realistic input. 224 is usual for classification and 128 a
    # TPU-friendly batch-size.
    imgs = np.ones((128, 224, 224), np.float32)
    imgs = jax.device_put(imgs)
# --- Example 2 (extraction artifact: fragment boundary; original marker "Exemplo n.º 2 / 0") ---
    # NOTE(review): fragment — the enclosing function's header (which defines
    # `a`, `b`, and `state`) was lost at an extraction boundary above.
    def fun(a, b):
        # Matrix product of the two operands.
        return a @ b

    fun(a, b).block_until_ready()  # ensure jit has finished
    while state:
        fun(a, b).block_until_ready()


datasets = ("pubmed", "citeseer", "cora")
# preload datasets to avoid spam later
for data_name in datasets:
    load_data(data_name)
for data_name in datasets:
    for dtype, dtype_str in ((jnp.float32, "f32"), (jnp.float64, "f64")):
        for backend in ("cpu", "gpu"):
            for fmt in "csr", "coo":
                benchmark.register(
                    partial(
                        matmul_benchmark,
                        fmt=fmt,
                        dtype=dtype,
                        backend=backend,
                        data_name=data_name,
                    ),
                    name="-".join((data_name, dtype_str, backend, fmt)),
                )


if __name__ == "__main__":
    benchmark.main()
# --- Example 3 (extraction artifact: fragment boundary; original marker "Exemplo n.º 3 / 0") ---
def benchmark_model(model: Model, data: SplitData, seed: int = 0):
    """Register and run google_benchmark timings for the model's steps.

    Registers train/validation (and, if present, test) step benchmarks and
    then invokes ``benchmark.main``, which parses argv and runs them.

    Args:
        model: project `Model` exposing ``compile``, ``init``,
            ``compiled_train_step``, ``compiled_test_step`` and
            ``init_metrics_state``.
        data: project `SplitData` with train/validation and optional test
            splits.
        seed: seed for haiku's ``PRNGSequence``.
    """
    import google_benchmark as benchmark  # pylint: disable=import-outside-toplevel

    # ``.repeat()`` makes the iterators effectively endless so the timed
    # loops below never exhaust their data.
    train_data = as_dataset(data.train_data).repeat()
    validation_data = as_dataset(data.validation_data).repeat()

    # Compile ahead of time on a zero-valued example shaped like the dataset
    # spec, keeping compilation cost out of the benchmark loops.
    dummy_example = jax.tree_map(zeros_like, train_data.element_spec)
    model.compile(*dummy_example)
    rng = hk.PRNGSequence(seed)
    params, net_state, opt_state = model.init(next(rng), dummy_example[0])
    train_step = model.compiled_train_step
    test_step = model.compiled_test_step
    metrics_state = model.init_metrics_state

    # The bare list comprehensions below exist only for their
    # ``block_until_ready`` side effects, hence the pylint disable.
    # pylint: disable=expression-not-assigned
    def train_benchmark(state):
        # Times one optimizer step per benchmark iteration.

        train_iter = iter(train_data)
        example = next(train_iter)

        # Warm-up step before the timed loop starts.
        params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
            params, net_state, next(rng), opt_state, metrics_state, *example)

        # Block on every output leaf so async dispatch cannot skew timing.
        [x.block_until_ready() for x in jax.tree_flatten(params_)[0]]
        while state:
            params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
                params_, net_state_, next(rng), opt_state_, metrics_state_,
                *example)
            # The next example is fetched after dispatching the step; the
            # first timed iteration reuses the warm-up example.
            example = next(train_iter)
            [x.block_until_ready() for x in jax.tree_flatten(params_)[0]]

    def test_benchmark(state, data):
        # Times one evaluation step per benchmark iteration over ``data``.
        metrics_state_ = metrics_state

        data_iter = iter(data)
        example = next(data_iter)
        # Warm-up evaluation; note it starts from the fresh
        # ``metrics_state``, while the loop threads ``metrics_state_``.
        metrics_state_, preds, loss, metrics = test_step(
            params, net_state, metrics_state, *example)
        [
            x.block_until_ready()
            for x in jax.tree_flatten((metrics_state_, metrics, preds,
                                       loss))[0]
        ]
        while state:
            metrics_state_, preds, loss, metrics = test_step(
                params, net_state, metrics_state_, *example)
            example = next(data_iter)
            [
                x.block_until_ready()
                for x in jax.tree_flatten((metrics_state_, metrics, preds,
                                           loss))[0]
            ]

    # pylint: enable=expression-not-assigned
    # NOTE(review): each "UNTRUSTWORTHY-" run precedes an identically-built
    # real registration — presumably it absorbs residual one-time costs so
    # the second run's numbers are trustworthy; confirm intent.
    benchmark.register(train_benchmark, name="UNTRUSTWORTHY-train")
    benchmark.register(train_benchmark, name="train_benchmark")
    benchmark.register(partial(test_benchmark, data=validation_data),
                       name="UNTRUSTWORTHY-validation")
    benchmark.register(partial(test_benchmark, data=validation_data),
                       name="validation_benchmark")
    if data.test_data is not None:
        test_data = as_dataset(data.test_data).repeat()

        benchmark.register(partial(test_benchmark, data=test_data),
                           name="test_benchmark")

    # Run everything registered above; only argv[0] is forwarded so
    # benchmark flags are not inherited from the surrounding program.
    benchmark.main(argv=sys.argv[:1])
# --- Example 4 (extraction artifact: fragment boundary; original marker "Exemplo n.º 4 / 0") ---

def indices_replica_id_calc_cached(mesh_shape, mesh_axes, state):
    """Benchmark shard-index/replica-id calculation on repeated calls.

    Every timed iteration passes identical arguments; per the function's
    name, repeated calls presumably hit a cache inside
    ``gda.get_shard_indices_replica_ids``.
    """
    shape = (2048, 2048)
    mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))

    while state:
        gda.get_shard_indices_replica_ids(shape, mesh, mesh_axes)


# Register the GDA construction / index-calculation benchmarks for every
# configured (mesh_shape, axes) pair.
benchmarks = []
for mesh_shape, axes in mesh_shapes_axes:
    registrations = (
        (partial(gda_construction_callback, axes),
         f"gda_construction_callback_(4, 2)_{axes}"),
        (partial(gda_construction_raw, mesh_shape, axes),
         f"gda_construction_raw_{mesh_shape}_{axes}"),
        (partial(indices_replica_id_calc_uncached, mesh_shape, axes),
         f"indices_replica_id_calc_uncached_{mesh_shape}_{axes}"),
        (partial(indices_replica_id_calc_cached, mesh_shape, axes),
         f"indices_replica_id_calc_cached_{mesh_shape}_{axes}"),
    )
    # NOTE(review): the callback benchmark's name hard-codes "(4, 2)" while
    # the others interpolate mesh_shape — preserved as-is; confirm intent.
    for fn, bench_name in registrations:
        benchmarks.append(google_benchmark.register(fn, name=bench_name))

if __name__ == "__main__":
    google_benchmark.main()