def test_performance(self):
    """Benchmark the fused LayerNorm forward/backward kernels.

    Runs ``runs`` timed iterations of ``fln.ln_fwd`` / ``fln.ln_bwd`` on a
    dedicated CUDA stream and prints the average per-iteration latency plus
    the effective memory bandwidth (GB/sec), computed from the total bytes
    each kernel reads and writes.

    Fix: the previously dead ``epsilon`` local is now actually passed to
    ``fln.ln_fwd`` instead of a duplicated ``1e-5`` literal (same value,
    consistent with ``benchmark_``).
    """
    print()
    runs = 1000
    device = torch.device('cuda')
    dtype = torch.float16
    s = 512
    b = 32
    hidden_size = 1024
    epsilon = 1e-5
    x = torch.randn((s * b, hidden_size), dtype=dtype, device=device)
    beta = torch.randn(hidden_size, dtype=dtype, device=device)
    gamma = torch.randn(hidden_size, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        timer = GPUTimer(stream)
        # Warmup so one-time kernel selection/caching doesn't skew the timing.
        for r in range(runs):
            y, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
        timer.start()
        for r in range(runs):
            y, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
        timer.stop()
        timer.sync()
        # Forward traffic: reads x/gamma/beta, writes y/mu/rsigma.
        total_bytes_fwd = (size_in_bytes(x) + size_in_bytes(y)
                           + size_in_bytes(gamma) + size_in_bytes(beta)
                           + size_in_bytes(mu) + size_in_bytes(rsigma))
        ms_fwd = timer.millis() / runs
        # bytes * 1e-6 / ms  ==  GB / sec
        print('[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(
            ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd))
        timer.start()
        for r in range(runs):
            dx, dgamma, dbeta = fln.ln_bwd(dy, x, mu, rsigma, gamma)
        timer.stop()
        timer.sync()
        # Backward traffic: reads dy/x/mu/rsigma/gamma, writes dx/dgamma/dbeta.
        total_bytes_bwd = (size_in_bytes(x) + size_in_bytes(dx)
                           + size_in_bytes(dy) + size_in_bytes(gamma)
                           + size_in_bytes(dgamma) + size_in_bytes(dbeta)
                           + size_in_bytes(mu) + size_in_bytes(rsigma))
        ms_bwd = timer.millis() / runs
        print('[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(
            ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd))
def benchmark_(S, B, hidden_size, itype, wtype, runs=100):
    """Time the fused LayerNorm kernels for a (S*B, hidden_size) input.

    Prints average forward/backward latency and effective bandwidth (GB/sec)
    over ``runs`` iterations, executed on a dedicated CUDA stream.
    """
    epsilon = 1e-5
    x = torch.randn((S * B, hidden_size), dtype=itype, device=device)
    beta = torch.randn(hidden_size, dtype=wtype, device=device)
    gamma = torch.randn(hidden_size, dtype=wtype, device=device)
    dz = torch.randn(x.shape, dtype=wtype, device=device)
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        timer = GPUTimer(stream)
        # Warmup pass before timing.
        for _ in range(runs):
            z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
        timer.start()
        for _ in range(runs):
            z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
        timer.stop()
        timer.sync()
        # Forward traffic: reads x/gamma/beta, writes z/mu/rsigma.
        fwd_tensors = [x, z, gamma, beta, mu, rsigma]
        total_bytes_fwd = sum([size_in_bytes(t) for t in fwd_tensors])
        ms_fwd = timer.millis() / runs
        print("[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
            ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd))
        timer.start()
        for _ in range(runs):
            dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma)
        timer.stop()
        timer.sync()
        # NOTE(review): dbp and dgp are each counted twice below — presumably
        # because the partial buffers are written by the bwd kernel and then
        # re-read by a reduction step. Kept exactly as in the original;
        # confirm against the kernel implementation.
        bwd_tensors = [dz, x, mu, rsigma, gamma, dx, dgamma, dbeta,
                       dbp, dbp, dgp, dgp]
        total_bytes_bwd = sum([size_in_bytes(t) for t in bwd_tensors])
        ms_bwd = timer.millis() / runs
        print("[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
            ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd))
def forward(ctx, x, gamma, beta, epsilon):
    """Run the fused LayerNorm forward kernel.

    Flattens ``x`` to a 2-D matrix whose row width is ``gamma.numel()``,
    invokes ``fast_layer_norm.ln_fwd``, stashes what the backward pass
    needs on ``ctx``, and returns the normalized output reshaped back to
    ``x``'s original shape.
    """
    # The extension expects contiguous buffers.
    x = x.contiguous()
    gamma = gamma.contiguous()
    beta = beta.contiguous()
    cols = gamma.numel()
    flat_x = x.view((-1, cols))
    out, mu, rsigma = fast_layer_norm.ln_fwd(flat_x, gamma, beta, epsilon)
    # Saved for backward: input, scale, and the per-row mean / inverse stddev.
    ctx.save_for_backward(x, gamma, mu, rsigma)
    return out.view(x.shape)
def test_(S, B, hidden_size, itype, wtype, ctype=fp32):
    """Compare the fused LayerNorm fwd/bwd kernels against a reference.

    Builds a reference forward pass in ``ctype`` precision and a reference
    backward via ``backward_``, runs the fused kernels, prints relative
    error / MSE for every output, and returns a list of per-output
    pass/fail booleans in the order [z, mu, rs, dx, dg, db].

    Fix: the ``rs`` metrics line previously printed the ``mu`` metrics
    again (copy-paste bug); it now prints ``re_rs`` / ``mse_rs``.
    """
    seed = 1243
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    otype = wtype
    print("========================================================")
    print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}")
    print("--------------------------------------------------------")
    x = torch.randn(S * B, hidden_size, dtype=itype, device=device)
    gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
    beta = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
    epsilon = 1e-5
    x.requires_grad = True
    gamma.requires_grad = True
    beta.requires_grad = True
    # Reference statistics accumulated in ctype (fp32 by default) for accuracy.
    mu_ref = x.mean(1, dtype=ctype, keepdim=True)
    v = torch.square(x - mu_ref).mean(1, dtype=ctype, keepdim=True)
    rs_ref = torch.rsqrt(v + epsilon)
    y_ref = rs_ref * (x.to(ctype) - mu_ref)
    z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype) + beta.unsqueeze(0)).to(otype)
    mu_ref = mu_ref.flatten()
    rs_ref = rs_ref.flatten()
    dz = torch.randn_like(z_ref)
    # Reference backward via the helper instead of autograd; the autograd
    # route would be: z_ref.backward(dz); x.grad / gamma.grad / beta.grad.
    dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma)
    z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon)
    dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma)
    re_z, mse_z = metrics(z_ref, z)
    re_mu, mse_mu = metrics(mu_ref, mu)
    re_rs, mse_rs = metrics(rs_ref, rs)
    re_dx, mse_dx = metrics(dx_ref, dx)
    re_dg, mse_dg = metrics(dg_ref, dg)
    re_db, mse_db = metrics(db_ref, db)
    print(f" z: relerr={re_z :.4e} mse={mse_z :.4e}")
    print(f"mu: relerr={re_mu:.4e} mse={mse_mu:.4e}")
    # BUG FIX: this line used to print re_mu/mse_mu a second time.
    print(f"rs: relerr={re_rs:.4e} mse={mse_rs:.4e}")
    print(f"dx: relerr={re_dx:.4e} mse={mse_dx:.4e}")
    print(f"dg: relerr={re_dg:.4e} mse={mse_dg:.4e}")
    print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}")

    def check_err(t, relerr):
        # fp16 outputs tolerate a much larger relative error than fp32.
        tol = 1e-3 if t.dtype == torch.float16 else 5e-6
        return relerr < tol

    return [
        check_err(t, re) for t, re in zip(
            [z, mu, rs, dx, dg, db],
            [re_z, re_mu, re_rs, re_dx, re_dg, re_db])
    ]