예제 #1
0
    def backward(ctx, dy):
        """Backward pass: compute grads of the input, gamma and beta.

        The upstream gradient may arrive non-contiguous, so it is forced
        contiguous before being handed to the fused kernel.
        """
        grad_out = dy.contiguous()  # non-contiguous grads do occur here
        x, gamma, mu, rsigma = ctx.saved_tensors

        cols = gamma.numel()
        x2d = x.view(-1, cols)          # collapse leading dims to rows
        dy2d = grad_out.view(x2d.shape)
        dx2d, dgamma, dbeta = fast_layer_norm.ln_bwd(dy2d, x2d, mu, rsigma, gamma)
        # Fourth forward argument (epsilon) is not differentiable -> None.
        return dx2d.view(x.shape), dgamma, dbeta, None
예제 #2
0
    def test_performance(self):
        """Benchmark fused layer-norm fwd/bwd latency and bandwidth on CUDA.

        Runs each pass `runs` times on a dedicated stream and reports the
        per-iteration time plus effective memory throughput (bytes moved
        per unit time).
        """
        print()
        runs = 1000
        device = torch.device('cuda')
        dtype = torch.float16
        s = 512            # sequence length
        b = 32             # batch size
        hidden_size = 1024
        epsilon = 1e-5

        x = torch.randn((s * b, hidden_size), dtype=dtype, device=device)
        beta = torch.randn(hidden_size, dtype=dtype, device=device)
        gamma = torch.randn(hidden_size, dtype=dtype, device=device)
        dy = torch.randn_like(x)

        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):

            timer = GPUTimer(stream)

            # warmup
            for r in range(runs):
                # FIX: pass the epsilon variable instead of a duplicated
                # hard-coded 1e-5 literal, keeping this benchmark consistent
                # if epsilon is ever changed.
                y, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)

            timer.start()
            for r in range(runs):
                y, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
            timer.stop()
            timer.sync()

            # Total bytes read + written by the forward kernel per call.
            total_bytes_fwd = (size_in_bytes(x) + size_in_bytes(y) +
                               size_in_bytes(gamma) + size_in_bytes(beta) +
                               size_in_bytes(mu) + size_in_bytes(rsigma))

            ms_fwd = timer.millis() / runs
            # bytes * 1e-6 = MB; MB per ms equals GB per sec.
            print('[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(
                ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd))

            timer.start()
            for r in range(runs):
                dx, dgamma, dbeta = fln.ln_bwd(dy, x, mu, rsigma, gamma)
            timer.stop()
            timer.sync()

            # Total bytes read + written by the backward kernel per call.
            total_bytes_bwd = (size_in_bytes(x) + size_in_bytes(dx) +
                               size_in_bytes(dy) + size_in_bytes(gamma) +
                               size_in_bytes(dgamma) + size_in_bytes(dbeta) +
                               size_in_bytes(mu) + size_in_bytes(rsigma))

            ms_bwd = timer.millis() / runs
            print('[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(
                ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd))
예제 #3
0
def benchmark_(S, B, hidden_size, itype, wtype, runs=100):
    """Time the fused layer-norm fwd/bwd kernels and print bandwidth.

    Inputs are random (S*B, hidden_size) tensors of dtype `itype`;
    weights/grads use dtype `wtype`. Each pass runs `runs` iterations
    on a dedicated CUDA stream after a warmup loop.
    """
    epsilon = 1e-5

    x = torch.randn((S * B, hidden_size), dtype=itype, device=device)
    beta = torch.randn(hidden_size, dtype=wtype, device=device)
    gamma = torch.randn(hidden_size, dtype=wtype, device=device)
    dz = torch.randn(x.shape, dtype=wtype, device=device)

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):

        timer = GPUTimer(stream)

        # Warm up kernels/caches before timing anything.
        for _ in range(runs):
            z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)

        timer.start()
        for _ in range(runs):
            z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
        timer.stop()
        timer.sync()

        # Bytes moved by the forward pass (all inputs + all outputs).
        fwd_tensors = (x, z, gamma, beta, mu, rsigma)
        total_bytes_fwd = sum(size_in_bytes(t) for t in fwd_tensors)

        ms_fwd = timer.millis() / runs

        print("[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
            ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd))

        timer.start()
        for _ in range(runs):
            dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma)
        timer.stop()
        timer.sync()

        # NOTE(review): dbp and dgp are counted twice on purpose here,
        # presumably because the partial-reduction buffers are written by
        # stage 1 and re-read by stage 2 of the backward kernel -- confirm
        # against the kernel implementation.
        bwd_tensors = (dz, x, mu, rsigma, gamma, dx, dgamma, dbeta,
                       dbp, dbp, dgp, dgp)
        total_bytes_bwd = sum(size_in_bytes(t) for t in bwd_tensors)

        ms_bwd = timer.millis() / runs

        print("[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
            ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd))
예제 #4
0
def test_(S, B, hidden_size, itype, wtype, ctype=fp32):
    """Compare fused layer-norm fwd/bwd against a PyTorch reference.

    Builds a (S*B, hidden_size) input of dtype `itype`, computes a
    reference forward/backward in `ctype` precision, runs the fused
    kernels, prints relative-error / MSE metrics, and returns a list of
    six booleans (z, mu, rs, dx, dg, db) -- True where the relative
    error is within tolerance.
    """
    seed = 1243
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    otype = wtype
    print("========================================================")
    print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}")
    print("--------------------------------------------------------")

    x = torch.randn(S * B, hidden_size, dtype=itype, device=device)
    gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
    beta = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
    epsilon = 1e-5

    x.requires_grad = True
    gamma.requires_grad = True
    beta.requires_grad = True

    # Reference statistics accumulated in ctype for accuracy.
    mu_ref = x.mean(1, dtype=ctype, keepdim=True)
    v = torch.square(x - mu_ref).mean(1, dtype=ctype, keepdim=True)
    rs_ref = torch.rsqrt(v + epsilon)
    y_ref = rs_ref * (x.to(ctype) - mu_ref)
    z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype) +
             beta.unsqueeze(0)).to(otype)

    mu_ref = mu_ref.flatten()
    rs_ref = rs_ref.flatten()

    dz = torch.randn_like(z_ref)

    # Autograd path kept for cross-checking; reference gradients come
    # from the hand-written backward_ helper instead.
    # z_ref.backward(dz)
    # dx_ref = x.grad
    # dgamma_ref = gamma.grad
    # dbeta_ref = beta.grad

    dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma)

    z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon)
    dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma)

    re_z, mse_z = metrics(z_ref, z)
    re_mu, mse_mu = metrics(mu_ref, mu)
    re_rs, mse_rs = metrics(rs_ref, rs)

    re_dx, mse_dx = metrics(dx_ref, dx)
    re_dg, mse_dg = metrics(dg_ref, dg)
    re_db, mse_db = metrics(db_ref, db)

    print(f" z: relerr={re_z :.4e} mse={mse_z :.4e}")
    print(f"mu: relerr={re_mu:.4e} mse={mse_mu:.4e}")
    # BUG FIX: the rs row previously printed re_mu/mse_mu (copy-paste),
    # so the rsigma metrics were never actually reported.
    print(f"rs: relerr={re_rs:.4e} mse={mse_rs:.4e}")

    print(f"dx: relerr={re_dx:.4e} mse={mse_dx:.4e}")
    print(f"dg: relerr={re_dg:.4e} mse={mse_dg:.4e}")
    print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}")

    def check_err(t, relerr):
        # fp16 outputs get a looser tolerance than full-precision ones.
        # (Parameter renamed from `x` to avoid shadowing the input tensor.)
        tol = 1e-3 if t.dtype == torch.float16 else 5e-6
        return relerr < tol

    return [
        check_err(t, re) for t, re in zip(
            [z, mu, rs, dx, dg, db], [re_z, re_mu, re_rs, re_dx, re_dg, re_db])
    ]