import torch
from torch.utils.benchmark import Timer


def time_with_torch_timer(fn, args, string_id, kwargs=None):
    """Benchmark fn(*args, **kwargs) forward and backward with torch.utils.benchmark.Timer.

    Latencies are reported in microseconds.
    """
    kwargs = kwargs or {}
    print("################################################")
    print(f"#### Torch Timer for {string_id} starts #########")
    print("################################################")
    ref = fn(*args, **kwargs)
    gO = torch.rand_like(ref)
    env = {"args": args, "gO": gO, "kwargs": kwargs, "fn": fn}
    grad_none = "for x in args: x.grad = None"
    fn_call = "fn(*args, **kwargs)"

    # Measure end-to-end fwd time
    timer = Timer(stmt=f"{fn_call}", globals=env)
    fwd_latency = round(timer.timeit(1000).mean * 10**6, 3)
    timer_blocked = timer.blocked_autorange()
    print(f"Forward = {fwd_latency} us")

    # Measure end-to-end fwd + bwd time. The statements are joined with newlines
    # so that only the grad reset runs inside the for loop.
    timer = Timer(
        stmt=f"{grad_none}\nfwd = {fn_call}\nfwd.backward(gO)",
        globals=env,
    )
    fwd_bwd_latency = round(timer.timeit(1000).mean * 10**6, 3)
    timer_blocked = timer.blocked_autorange()
    # print(f"Forward + Backward = {fwd_bwd_latency} us")
    bwd_latency = round(fwd_bwd_latency - fwd_latency, 3)
    print(f"Backward = {bwd_latency} us")

    print("################################################")
    print(f"#### Torch Timer for {string_id} ends ###############")
    print("################################################\n\n\n\n")
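
# Minimal usage sketch. Everything below is illustrative and not part of the
# original helper: ``torch.mm`` and the 1024x1024 shapes are arbitrary choices.
# The inputs must require grad so that ``fwd.backward(gO)`` has something to
# propagate into.
if __name__ == "__main__":
    a = torch.randn(1024, 1024, requires_grad=True)
    b = torch.randn(1024, 1024, requires_grad=True)
    time_with_torch_timer(torch.mm, (a, b), "torch.mm 1024x1024")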
# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums
# argument that says which argument we would like to compute Jacobians with
# respect to.
from functorch import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)

# Let's compare the performance of the two ways of computing the Jacobian.
# The functorch version is much faster (and becomes even faster the more
# outputs there are). In general, we expect that vectorization via ``vmap``
# can help eliminate overhead and give better utilization of your hardware.
from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(without_vmap.timeit(500))
print(with_vmap.timeit(500))

# It's pretty easy to flip the problem around and say we want to compute
# Jacobians of the parameters to our model (weight, bias) instead of the input.
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)

######################################################################
# Reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)
# --------------------------------------------------------------------
# We offer two APIs to compute Jacobians: jacrev and jacfwd:
#
# - jacrev uses reverse-mode AD. As you saw above, it is a composition of our
#   vjp and vmap transforms.
# - jacfwd uses forward-mode AD. It is implemented as a composition of our
#   jvp and vmap transforms.
#
# jacfwd and jacrev can be substituted for each other and have different
# performance characteristics.
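
# A quick sketch of that drop-in relationship (hedged: it reuses the
# ``predict``, ``weight``, ``bias``, and ``x`` defined earlier in this
# tutorial). Both transforms produce the same Jacobian; only their cost
# profiles differ.
from functorch import jacfwd

ft_jacobian_fwd = jacfwd(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian_fwd, ft_jacobian)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(using_fwd.timeit(500))
print(using_bwd.timeit(500))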
import argparse

import torch
from torch.utils.benchmark import Timer

# Imports assumed for this snippet; module paths follow the current e3nn layout.
# ``t_or_f`` (a string-to-bool argparse helper) is defined elsewhere in the
# original benchmark script.
from e3nn.o3 import ElementwiseTensorProduct, FullyConnectedTensorProduct, Irreps
from e3nn.util.jit import compile


def main():
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps", type=str, default="8x0e + 8x1e + 8x2e + 8x3o")
    parser.add_argument("--irreps-in1", type=str, default=None)
    parser.add_argument("--irreps-in2", type=str, default=None)
    parser.add_argument("--irreps-out", type=str, default=None)
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("--elementwise", action='store_true')
    parser.add_argument("-n", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1 if args.irreps_in1 else args.irreps)
    irreps_in2 = Irreps(args.irreps_in2 if args.irreps_in2 else args.irreps)
    irreps_out = Irreps(args.irreps_out if args.irreps_out else args.irreps)

    if args.elementwise:
        tp = ElementwiseTensorProduct(
            irreps_in1,
            irreps_in2,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein,
        )
        if args.backward:
            print("Elementwise TP has no weights, cannot backward. Setting --backward False.")
            args.backward = False
    else:
        tp = FullyConnectedTensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein,
        )
    tp = tp.to(device=device)
    assert len(tp.instructions) > 0, "Bad irreps, no instructions"
    print(f"Tensor product: {tp}")
    print("Instructions:")
    for ins in tp.instructions:
        print(f"  {ins}")

    # Generate enough random inputs to cover the timed runs plus Timer's warmup; see
    # https://pytorch.org/docs/master/_modules/torch/utils/benchmark/utils/timer.html#Timer.timeit
    warmup = max(int(args.n // 100), 1)
    inputs = iter([
        (irreps_in1.randn(args.batch, -1).to(device=device),
         irreps_in2.randn(args.batch, -1).to(device=device))
        for _ in range(args.n + warmup)
    ])

    # Optionally TorchScript-compile the module
    if args.jit:
        tp = compile(tp)

    print("starting...")

    # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
    t = Timer(
        stmt=("tp.zero_grad()\n"
              "out = tp(*next(inputs))\n"
              + ("out.tanh().sum().backward()\n" if args.backward else '')),
        globals={'tp': tp, 'inputs': inputs},
    )

    perloop = t.timeit(args.n)

    print()
    print(perloop)
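
# Example entry point and invocations (a sketch; the original script may
# already provide its own ``__main__`` guard). The flags used below are the
# ones registered on the parser above:
#
#   python tensor_product_benchmark.py --irreps "16x0e + 16x1e + 16x2e" --batch 32 -n 2000
#   python tensor_product_benchmark.py --elementwise --backward False --jit False
if __name__ == "__main__":
    main()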