Exemplo n.º 1
0
def time_with_torch_timer(fn, args, string_id, kwargs=None):
    """Print the forward and backward latency of ``fn`` using a torch ``Timer``.

    ``fn(*args, **kwargs)`` is first timed on its own, then timed together
    with a backward pass.  The backward latency is reported as the difference
    between the two means.  Both numbers are the ``Timer`` mean scaled by
    ``10**6`` (seconds to microseconds) and rounded to three decimals.

    Args:
        fn: Callable whose result supports ``.backward(grad)``.  It is called
            once up front so an upstream gradient of matching shape can be
            created with ``torch.rand_like``.
        args: Positional arguments for ``fn``.  Tensors among them get their
            ``.grad`` cleared before every timed forward+backward iteration.
        string_id: Label shown in the printed banner.
        kwargs: Optional keyword arguments for ``fn``; defaults to none.

    Returns:
        None.  Results are only printed.
    """
    if kwargs is None:
        kwargs = {}

    print("################################################")
    print(f"#### Torch Timer for {string_id} starts #########")
    print("################################################")
    ref = fn(*args, **kwargs)
    gO = torch.rand_like(ref)
    # ``torch`` is exposed explicitly because the timed statement below uses it
    # for the isinstance check.
    env = {"args": args, "gO": gO, "kwargs": kwargs, "fn": fn, "torch": torch}
    fn_call = "fn(*args, **kwargs)"

    # Measure end-to-end fwd time
    timer = Timer(stmt=fn_call, globals=env)
    fwd_latency = round(timer.timeit(1000).mean * 10**6, 3)
    print(f"Forward = {fwd_latency}")

    # Measure end-to-end fwd bwd.
    # The statement must be multi-line: on a single line, everything after
    # ``for ...:`` would become part of the loop body.  Gradients are cleared
    # on every iteration so backward never pays for accumulating into a stale
    # ``.grad``; non-tensor arguments are skipped.
    fwd_bwd_stmt = "\n".join((
        "for x in args:",
        "    if isinstance(x, torch.Tensor):",
        "        x.grad = None",
        f"fwd = {fn_call}",
        "fwd.backward(gO)",
    ))
    timer = Timer(stmt=fwd_bwd_stmt, globals=env)
    fwd_bwd_latency = round(timer.timeit(1000).mean * 10**6, 3)

    # The (cheap) gradient reset is part of the second measurement, so it is
    # attributed to the backward number.
    bwd_latency = round(fwd_bwd_latency - fwd_latency, 3)
    print(f"Backward = {bwd_latency}")

    print("################################################")
    print(f"#### Torch Timer for {string_id} ends ###############")
    print("################################################\n\n\n\n")
Exemplo n.º 2
0
# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums
# argument that says which argument we would like to compute Jacobians with
# respect to.
# ``predict``, ``weight``, ``bias``, ``x``, ``jacobian``, ``compute_jac`` and
# ``xp`` are defined outside this excerpt.
from functorch import jacrev
# argnums=2 selects the third positional argument of ``predict`` (here ``x``)
# as the one to differentiate with respect to.
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# Sanity check against the previously computed ``jacobian``.
assert torch.allclose(ft_jacobian, jacobian)

# Let's compare the performance of the two ways to compute jacobian.
# The functorch version is much faster (and becomes even faster the more outputs
# there are). In general, we expect that vectorization via ``vmap`` can help
# eliminate overhead and give better utilization of your hardware.
from torch.utils.benchmark import Timer
# The statements are given as strings and resolved against ``globals()``, so
# every name they mention must exist at module level under exactly that name.
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)",
                  globals=globals())
# Run each statement 500 times and print the resulting measurement.
print(without_vmap.timeit(500))
print(with_vmap.timeit(500))

# It's pretty easy to flip the problem around and say we want to compute
# Jacobians of the parameters to our model (weight, bias) instead of the input.
# Passing a tuple as ``argnums`` yields one result per selected argument,
# unpacked here in the same order.
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)

######################################################################
# reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)
# --------------------------------------------------------------------
# We offer two APIs to compute jacobians: jacrev and jacfwd:
# - jacrev uses reverse-mode AD. As you saw above it is a composition of our
#   vjp and vmap transforms.
# - jacfwd uses forward-mode AD. It is implemented as a composition of our
#   jvp and vmap transforms.
# jacfwd and jacrev can be substituted for each other and have different
Exemplo n.º 3
0
def main():
    """Benchmark a tensor product module from the command line.

    Steps:
      1. Parse the command line options (irreps, device, jit, backward, ...).
      2. Build either an ``ElementwiseTensorProduct`` or a
         ``FullyConnectedTensorProduct`` and move it to the selected device.
      3. Pre-generate every random input batch, so that input creation is not
         part of the timed statement.
      4. Optionally compile the module, then time ``-n`` executions of the
         forward pass (plus backward unless disabled) with ``Timer`` and
         print the measurement.

    CUDA is used only when it is both requested (``--cuda``) and available.
    ``--backward`` is forced to ``False`` for the elementwise product.

    Raises:
        AssertionError: if the chosen irreps produce no instructions.
    """
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3o")
    parser.add_argument("--irreps-in1", type=str, default=None)
    parser.add_argument("--irreps-in2", type=str, default=None)
    parser.add_argument("--irreps-out", type=str, default=None)
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("--elementwise", action='store_true')
    parser.add_argument("-n", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    # Record the device actually used so the settings dump below is accurate.
    args.cuda = device == 'cuda'

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    # The specific --irreps-* options fall back to the shared --irreps value.
    irreps_in1 = Irreps(args.irreps_in1 if args.irreps_in1 else args.irreps)
    irreps_in2 = Irreps(args.irreps_in2 if args.irreps_in2 else args.irreps)
    irreps_out = Irreps(args.irreps_out if args.irreps_out else args.irreps)

    if args.elementwise:
        tp = ElementwiseTensorProduct(irreps_in1,
                                      irreps_in2,
                                      _specialized_code=args.specialized_code,
                                      _optimize_einsums=args.opt_ein)
        if args.backward:
            print(
                "Elementwise TP has no weights, cannot backward. Setting --backward False."
            )
            args.backward = False
    else:
        tp = FullyConnectedTensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein)
    tp = tp.to(device=device)
    assert len(tp.instructions) > 0, "Bad irreps, no instructions"
    print(f"Tensor product: {tp}")
    print("Instructions:")
    for ins in tp.instructions:
        print(f"  {ins}")

    # Timer.timeit runs un-timed warm-up iterations before the timed ones, and
    # every iteration pulls one item from `inputs`.  Recent PyTorch releases
    # warm up with max(n // 100, 2) iterations (older ones used a minimum of
    # 1), so reserve for the larger count; otherwise small -n values exhaust
    # the iterator and raise StopIteration inside the timed statement.
    # see https://pytorch.org/docs/master/_modules/torch/utils/benchmark/utils/timer.html#Timer.timeit
    warmup = max(args.n // 100, 2)

    # Materialised up front (a list, not a lazy generator) so that random input
    # generation and the device transfer are not part of the measurement.
    inputs = iter([(irreps_in1.randn(args.batch, -1).to(device=device),
                    irreps_in2.randn(args.batch, -1).to(device=device))
                   for _ in range(args.n + warmup)])

    # compile
    if args.jit:
        tp = compile(tp)

    print("starting...")

    # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
    t = Timer(
        stmt=("tp.zero_grad()\n"
              "out = tp(*next(inputs))\n" +
              ("out.tanh().sum().backward()\n" if args.backward else '')),
        globals={
            'tp': tp,
            'inputs': inputs
        })

    perloop = t.timeit(args.n)

    print()
    print(perloop)