Example #1
0
def test_forward(T, D, B, L, fp16, weighted):
    """Compare the batched forward pass with per-table ``torch.nn.EmbeddingBag``.

    ``T`` sum-mode reference bags are created on the GPU, their weights are
    copied into a single ``TableBatchedEmbeddingBags`` module, and the output
    of that module is asserted to be close to the reference outputs placed
    side by side along dimension 1.

    Args:
        T: number of embedding tables.
        D: embedding dimension divided by four (``4 * D`` is used).
        B: number of rows in each index tensor.
        L: number of indices per row.
        fp16: convert bags and per-sample weights to half precision.
        weighted: pass per-sample weights to both implementations.
    """
    D = D * 4
    E = int(1e4)

    # Reference bags first, so the RNG is consumed in the same order as the
    # index and weight tensors created below.
    reference_bags = []
    for _ in range(T):
        bag = torch.nn.EmbeddingBag(E, D, mode="sum", sparse=True).cuda()
        reference_bags.append(bag.half() if fp16 else bag)

    index_batches = [
        torch.randint(low=0, high=E, size=(B, L)).cuda() for _ in range(T)
    ]
    weight_batches = []
    for _ in range(T):
        sample_weights = torch.randn(size=(B, L)).cuda()
        weight_batches.append(sample_weights.half() if fp16 else sample_weights)

    if weighted:
        reference_outputs = [
            bag(idx, per_sample_weights=w)
            for (bag, idx, w) in zip(reference_bags, index_batches, weight_batches)
        ]
    else:
        reference_outputs = [
            bag(idx) for (bag, idx) in zip(reference_bags, index_batches)
        ]

    # (B, D) per table -> (B, T, D)
    expected = torch.stack(reference_outputs, dim=1)

    cc = table_batched_embeddings_ops.TableBatchedEmbeddingBags(
        T, E, D, fp16=fp16).cuda()

    # Make the batched module use exactly the reference weights.
    batched_tables = cc.embedding_weights.data.view(T, E, D)
    for t, bag in enumerate(reference_bags):
        batched_tables[t, :, :] = bag.weight

    # (B, L) per table -> (T, B, L)
    stacked_indices = torch.stack(index_batches, dim=0)
    stacked_weights = torch.stack(weight_batches, dim=0)

    (indices, offsets) = get_table_batched_offsets_from_dense(stacked_indices)
    if weighted:
        actual = cc(indices, offsets,
                    stacked_weights.contiguous().view(-1).cuda())
    else:
        actual = cc(indices, offsets)

    torch.testing.assert_allclose(expected, actual)
Example #2
0
def test_backward_sgd(T, D, B, L, fp16):
    """Check the fused SGD backward pass against a manual SGD step.

    ``T`` reference ``torch.nn.EmbeddingBag`` modules are run forward and
    backward with random output gradients, and the weights an SGD step would
    produce (``weight - lr * grad``) are computed by hand. The same weights,
    indices and output gradients are then pushed through a
    ``TableBatchedEmbeddingBags`` module configured with the SGD optimizer,
    and its weights after ``backward`` are asserted to match.

    Args:
        T: number of embedding tables.
        D: embedding dimension divided by four (``4 * D`` is used).
        B: number of rows in each index tensor.
        L: number of indices per row.
        fp16: convert the reference bags to half precision and build the
            batched module with ``fp16=True``.
    """
    D = D * 4
    E = int(1e4)
    # Single source of truth for the step size used by both implementations.
    lr = 0.05

    bs = [
        torch.nn.EmbeddingBag(E, D, mode="sum", sparse=True).cuda()
        for _ in range(T)
    ]
    if fp16:
        bs = [b.half() for b in bs]

    # Sampling without replacement gives every table B * L distinct indices
    # (this requires B * L <= E).
    xs = [
        torch.from_numpy(
            np.random.choice(range(E), size=(B, L),
                             replace=False).astype(np.int64)).cuda()
        for _ in range(T)
    ]

    def b_indices(b, x):
        """Run bag ``b`` on dense indices ``x`` via the (indices, offsets) form."""
        (indices, offsets) = get_offsets_from_dense(x)
        return b(indices.long(), offsets.to(torch.int64))

    fs = [b_indices(b, x) for (b, x) in zip(bs, xs)]
    gos = [torch.randn_like(f) for f in fs]
    for (f, go) in zip(fs, gos):
        f.backward(go)

    # Expected weights after one plain SGD step on each reference table.
    new_weights = [(b.weight - b.weight.grad * lr) for b in bs]

    cc = table_batched_embeddings_ops.TableBatchedEmbeddingBags(
        T,
        E,
        D,
        optimizer=table_batched_embeddings_ops.Optimizer.SGD,
        learning_rate=lr,
        fp16=fp16,
    ).cuda()

    # Start the batched module from the reference weights.
    for t in range(T):
        cc.embedding_weights.data.view(T, E, D)[t, :, :] = bs[t].weight

    x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
    (indices, offsets) = get_table_batched_offsets_from_dense(x)
    fc2 = cc(indices, offsets)
    # The checks below read the weights right after backward(), so the
    # optimizer step is expected to have been applied as part of it.
    fc2.backward(torch.cat([go.view(B, 1, D) for go in gos], dim=1))
    for t in range(T):
        torch.testing.assert_allclose(
            cc.embedding_weights.view(T, E, D)[t, :, :], new_weights[t])
def benchmark_forward(B, E, T, L, D, iters, fp16, managed, mixed):
    """Benchmark the table-batched embedding forward and Adagrad backward kernels.

    A batched embedding module is built (uniform or mixed dimension), random
    indices of shape ``(T, B, L)`` are drawn, every forward configuration is
    validated against the ``BT_block_size=1`` / ``shmem=False`` result, and
    then forward and row-wise Adagrad backward kernels are timed with
    ``benchmark_torch_function``. Results are reported through ``logging``.

    Args:
        B: number of rows per table in the index tensor.
        E: number of rows per embedding table.
        T: number of embedding tables.
        L: number of indices per bag (asserted against the offsets).
        D: embedding dimension. With ``mixed`` it is the centre of the range
            from which per-table dimensions are drawn and is then replaced by
            their mean for reporting.
        iters: iteration count forwarded to ``benchmark_torch_function``.
        fp16: build the module with ``fp16=True`` and use 2 instead of 4 bytes
            per element in the bandwidth figures.
        managed: use ``EmbeddingLocation.HOST_MAPPED`` instead of ``DEVICE``.
        mixed: use ``MixedDimTableBatchedEmbeddingBags`` and the ``*_mixed_D``
            kernels.
    """
    logging.basicConfig(level=logging.DEBUG)
    import torch
    import table_batched_embeddings

    np.random.seed(42)
    if mixed:
        mixed_D = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        # From here on D is only used in the log messages of the mixed case.
        D = np.average(mixed_D)

    ops = table_batched_embeddings_ops
    module_kwargs = dict(
        optimizer=ops.Optimizer.APPROX_ROWWISE_ADAGRAD,
        learning_rate=0.1,
        managed=(
            ops.EmbeddingLocation.HOST_MAPPED
            if managed
            else ops.EmbeddingLocation.DEVICE
        ),
        eps=0.1,
        stochastic_rounding=False,
        fp16=fp16,
    )
    if mixed:
        cc = ops.MixedDimTableBatchedEmbeddingBags(
            [(E, d) for d in mixed_D], **module_kwargs
        ).cuda()
    else:
        cc = ops.TableBatchedEmbeddingBags(T, E, D, **module_kwargs).cuda()

    logging.info(
        f"Embedding parameters: {cc.embedding_weights.numel() / 1.0e9:.2f}GParam"
    )

    # Set to True to re-draw the indices in place before every timed call.
    R = False

    def randomized(kernel, indices_pos):
        """Wrap ``kernel`` so positional argument ``indices_pos`` is re-drawn per call."""
        if not R:
            return kernel

        @functools.wraps(kernel)
        def wrapper(*args):
            args[indices_pos].random_(0, E - 1)
            return kernel(*args)

        return wrapper

    merged_indices = torch.randint(low=0, high=E - 1, size=(T, B, L)).int().cuda()

    # Report duplicates for indices that are actually benchmarked (first
    # table), not for an unrelated sample.
    sample = merged_indices[0].detach().cpu().numpy()
    print(f"Duplicate proportion: {1.0 - np.unique(sample).size / sample.size}")

    (indices, offsets) = get_table_batched_offsets_from_dense(merged_indices)
    assert indices.shape[0] == B * T * L
    assert all(
        n == L for n in (offsets[1:] - offsets[:-1]).detach().cpu().numpy().tolist()
    )
    per_sample_weights = None
    print(indices.shape, indices.min(), indices.max(), indices)

    def forward_spec(bt_block_size, shmem):
        """Return (kernel, position of the indices argument, args) for one forward run."""
        if mixed:
            return (
                table_batched_embeddings.forward_mixed_D,
                4,
                (
                    cc.embedding_weights,
                    cc.table_offsets,
                    cc.dim_offsets,
                    cc.total_D,
                    indices,
                    offsets,
                    per_sample_weights,
                    L,
                    bt_block_size,
                    shmem,
                ),
            )
        return (
            table_batched_embeddings.forward,
            2,
            (
                cc.embedding_weights,
                cc.table_offsets,
                indices,
                offsets,
                per_sample_weights,
                L,
                bt_block_size,
                shmem,
            ),
        )

    kernel, _, args = forward_spec(1, False)
    y0 = kernel(*args)

    forward_block_sizes = [1, 2, 4, 8, 16, 32, 64, 128]

    # Validate exactly the configurations that are timed below, including the
    # shared-memory variant of the mixed-dimension kernel.
    for BT_block_size in forward_block_sizes:
        for shmem in [True, False]:
            kernel, _, args = forward_spec(BT_block_size, shmem)
            torch.testing.assert_allclose(kernel(*args), y0)

    for BT_block_size in forward_block_sizes:
        for shmem in [True, False]:
            kernel, indices_pos, args = forward_spec(BT_block_size, shmem)
            time_per_iter = benchmark_torch_function(
                iters, randomized(kernel, indices_pos), *args
            )
            logging.info(
                f"Forward, B: {B} {(BT_block_size, shmem)}, E: {E}, T: {T}, D: {D}, L: {L}, BW: {(2 if fp16 else 4) * B * T * L * D / time_per_iter / 1.0e9: .2f}GB/s, T: {time_per_iter * 1.0e6:.0f}us"
            )

    go = torch.randn_like(y0)

    learning_rate = 0.05
    eps = 0.01

    def adagrad_spec(exact, stochastic, bt_block_size):
        """Return (kernel, position of the indices argument, args) for one Adagrad backward run."""
        # Only the approximate kernels take L ahead of the stochastic flag.
        if exact:
            tail = (stochastic, bt_block_size)
        else:
            tail = (L, stochastic, bt_block_size)
        if mixed:
            kernel = (
                table_batched_embeddings.backward_exact_adagrad_mixed_D
                if exact
                else table_batched_embeddings.backward_approx_adagrad_mixed_D
            )
            head = (
                go,
                cc.embedding_weights,
                cc.table_offsets,
                cc.table_dim_offsets,
                cc.dim_offsets,
                cc.total_D,
            )
            indices_pos = 6
        else:
            kernel = (
                table_batched_embeddings.backward_exact_adagrad
                if exact
                else table_batched_embeddings.backward_approx_adagrad
            )
            head = (go, cc.embedding_weights, cc.table_offsets)
            indices_pos = 3
        shared = (
            indices,
            offsets,
            per_sample_weights,
            cc.optimizer_state,
            learning_rate,
            eps,
        )
        return kernel, indices_pos, head + shared + tail

    # NOTE: the SGD backward kernel is not benchmarked by this function.
    for BT_block_size in [1, 2, 4, 8, 16, 32]:
        for exact in [0, 1]:
            # Stochastic rounding is only exercised for fp16.
            for stochastic in [0, 1] if fp16 else [0]:
                kernel, indices_pos, args = adagrad_spec(
                    exact, stochastic, BT_block_size
                )
                time_per_iter = benchmark_torch_function(
                    iters, randomized(kernel, indices_pos), *args
                )

                logging.info(
                    f"Backward-ADAGRAD-{'nonstochastic' if not stochastic else 'stochastic'}-{'EXACT' if exact else 'APPROX'}-{'R' if R else 'NR'}, B: {B} ({BT_block_size}), E: {E}, T: {T}, D: {D}, L: {L}, BW: {2 * (2 if fp16 else 4) * B * T * L * D / time_per_iter / 1.0e9: .2f}GB/s, T: {time_per_iter * 1.0e6:.0f}us"
                )
Example #4
0
def benchmark_embedding_lookup(B, E, T, L, D, BT_block_size, iters, warmup_iters, backward, shmem, sgd, fp16, managed, mixed):
    """Benchmark one table-batched embedding lookup configuration.

    Builds a batched embedding module, draws random indices per table, checks
    that the requested ``BT_block_size`` gives the same forward result as
    ``BT_block_size=1``, and then times either the forward kernel or one
    backward kernel (SGD or exact row-wise Adagrad) with
    ``benchmark_torch_function``. The result is reported through ``logging``.

    Args:
        B: number of rows per table in the index tensor.
        E: table sizes as a string, either a single integer or ``T`` integers
            separated by ``-``.
        T: number of embedding tables.
        L: number of indices per bag (asserted against the offsets).
        D: embedding dimension. With ``mixed`` it is the centre of the range
            from which per-table dimensions are drawn and is then replaced by
            their mean for reporting.
        BT_block_size: block-size argument forwarded to the kernels.
        iters: timed iteration count forwarded to ``benchmark_torch_function``.
        warmup_iters: warm-up iteration count forwarded to
            ``benchmark_torch_function``.
        backward: time a backward kernel instead of the forward kernel.
        shmem: shared-memory flag forwarded to the forward and SGD kernels.
        sgd: with ``backward``, time the SGD kernel instead of Adagrad.
        fp16: build the module with ``fp16=True`` and use 2 instead of 4 bytes
            per element in the bandwidth figures.
        managed: use ``EmbeddingLocation.HOST_MAPPED`` instead of ``DEVICE``.
        mixed: use ``MixedDimTableBatchedEmbeddingBags`` and the ``*_mixed_D``
            kernels.
    """
    Es = [int(x) for x in E.split('-')]
    if len(Es) == 1:
        Es = Es * T
    assert len(Es) == T

    if mixed:
        mixed_D = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        # From here on D is only used in the log messages of the mixed case.
        D = np.average(mixed_D)

    ops = table_batched_embeddings_ops
    module_kwargs = dict(
        optimizer=ops.Optimizer.APPROX_ROWWISE_ADAGRAD,
        learning_rate=0.1,
        managed=(
            ops.EmbeddingLocation.HOST_MAPPED
            if managed
            else ops.EmbeddingLocation.DEVICE
        ),
        eps=0.1,
        stochastic_rounding=False,
        fp16=fp16,
    )
    if mixed:
        # One (rows, dim) pair per table: pair each table's row count with
        # its own dimension instead of passing the whole list of row counts.
        cc = ops.MixedDimTableBatchedEmbeddingBags(
            list(zip(Es, mixed_D)), **module_kwargs
        ).cuda()
    else:
        cc = ops.TableBatchedEmbeddingBags(T, Es, D, **module_kwargs).cuda()

    logging.info(
        f"Embedding parameters: {cc.embedding_weights.numel() / 1.0e9:.2f}GParam"
    )

    # Set to True to re-draw the indices in place before every timed call.
    R = False

    def randomized(kernel, indices_pos):
        """Wrap ``kernel`` so positional argument ``indices_pos`` is re-drawn per call."""
        if not R:
            return kernel

        # E is a string here; bound the draw by the smallest table so the
        # indices stay valid for every table.
        upper = min(Es) - 1

        @functools.wraps(kernel)
        def wrapper(*args):
            args[indices_pos].random_(0, upper)
            return kernel(*args)

        return wrapper

    idxs = [
        torch.randint(low=0, high=rows - 1, size=(B, L)).int().cuda() for rows in Es
    ]
    merged_indices = torch.stack(idxs, dim=0)

    (indices, offsets) = get_table_batched_offsets_from_dense(merged_indices)

    assert indices.shape[0] == B * T * L
    assert all(
        n == L for n in (offsets[1:] - offsets[:-1]).detach().cpu().numpy().tolist()
    )
    per_sample_weights = None
    stochastic = False # TODO: Fix this
    exact = 1

    def forward_spec(bt_block_size):
        """Return (kernel, position of the indices argument, args) for one forward run."""
        if mixed:
            return (
                table_batched_embeddings.forward_mixed_D,
                4,
                (
                    cc.embedding_weights,
                    cc.table_offsets,
                    cc.dim_offsets,
                    cc.total_D,
                    indices,
                    offsets,
                    per_sample_weights,
                    L,
                    bt_block_size,
                    shmem,
                ),
            )
        return (
            table_batched_embeddings.forward,
            2,
            (
                cc.embedding_weights,
                cc.table_offsets,
                indices,
                offsets,
                per_sample_weights,
                L,
                bt_block_size,
                shmem,
            ),
        )

    # Reference result with block size 1, then the configuration under test;
    # both use the caller's shmem setting.
    kernel, _, args = forward_spec(1)
    y0 = kernel(*args)
    kernel, indices_pos, args = forward_spec(BT_block_size)
    y = kernel(*args)
    torch.testing.assert_allclose(y, y0)

    if not backward:
        time_per_iter = benchmark_torch_function(
            iters, warmup_iters, randomized(kernel, indices_pos), *args
        )
        logging.info(
            f"Embedding Lookup Forward, B: {B} {(BT_block_size, shmem)}, E: {E}, T: {T}, D: {D}, L: {L}, BW: {(2 if fp16 else 4) * B * T * L * D / time_per_iter / 1.0e9: .2f}GB/s, Time: {time_per_iter * 1.0e6:.0f}us"
        )
        return

    go = torch.randn_like(y0)

    learning_rate = 0.05
    eps = 0.01

    if sgd:
        # There is no mixed-dimension variant on this path; the plain SGD
        # kernel is timed regardless of ``mixed``.
        time_per_iter = benchmark_torch_function(
            iters,
            warmup_iters,
            randomized(table_batched_embeddings.backward_sgd, 3),
            go,
            cc.embedding_weights,
            cc.table_offsets,
            indices,
            offsets,
            learning_rate,
            L,
            BT_block_size,
            shmem,
        )

        logging.info(
            f"Embedding Lookup Backward-SGD, B: {B} {(BT_block_size, shmem)}, E: {E}, T: {T}, D: {D}, L: {L}, BW: {2 * (2 if fp16 else 4) * B * T * L * D / time_per_iter / 1.0e9: .2f}GB/s, Time: {time_per_iter * 1.0e6:.0f}us"
        )
        return

    # Adagrad. Only the approximate kernels take L ahead of the stochastic flag.
    if exact:
        tail = (stochastic, BT_block_size)
    else:
        tail = (L, stochastic, BT_block_size)
    if mixed:
        kernel = (
            table_batched_embeddings.backward_exact_adagrad_mixed_D
            if exact
            else table_batched_embeddings.backward_approx_adagrad_mixed_D
        )
        head = (
            go,
            cc.embedding_weights,
            cc.table_offsets,
            cc.table_dim_offsets,
            cc.dim_offsets,
            cc.total_D,
        )
        indices_pos = 6
    else:
        kernel = (
            table_batched_embeddings.backward_exact_adagrad
            if exact
            else table_batched_embeddings.backward_approx_adagrad
        )
        head = (go, cc.embedding_weights, cc.table_offsets)
        indices_pos = 3
    shared = (
        indices,
        offsets,
        per_sample_weights,
        cc.optimizer_state,
        learning_rate,
        eps,
    )
    time_per_iter = benchmark_torch_function(
        iters,
        warmup_iters,
        randomized(kernel, indices_pos),
        *(head + shared + tail),
    )

    logging.info(
        f"Embedding Lookup Backward-ADAGRAD-{'nonstochastic' if not stochastic else 'stochastic'}-{'EXACT' if exact else 'APPROX'}-{'R' if R else 'NR'}, B: {B} ({BT_block_size}), E: {E}, T: {T}, D: {D}, L: {L}, BW: {2 * (2 if fp16 else 4) * B * T * L * D / time_per_iter / 1.0e9: .2f}GB/s, Time: {time_per_iter * 1.0e6:.0f}us"
    )
Example #5
0
def test_backward_adagrad(T, D, B, L, D_gradcheck, fp16, stochastic_rounding,
                          weighted, exact):
    """Check the fused row-wise Adagrad backward pass against a dense reference.

    ``T`` reference ``torch.nn.EmbeddingBag`` modules are run forward and
    backward with random output gradients. The same weights, indices and
    output gradients are pushed through a ``TableBatchedEmbeddingBags`` module
    using the approximate or exact row-wise Adagrad optimizer, and two things
    are asserted per table:

    * the optimizer state equals the per-row sum of squared reference
      gradients;
    * the weights equal ``w - lr * grad / (sqrt(state) + eps)``.

    With ``weighted`` set, ``torch.autograd.gradcheck`` is additionally run on
    the per-sample weights of a double-precision module with learning rate 0.

    Args:
        T: number of embedding tables.
        D: embedding dimension divided by four (``4 * D`` is used).
        B: number of rows in each index tensor.
        L: number of indices per row.
        D_gradcheck: embedding dimension, divided by four, of the module used
            for the gradcheck.
        fp16: convert reference bags and per-sample weights to half precision
            and build the batched module with ``fp16=True``.
        stochastic_rounding: forwarded to the batched module.
        weighted: use per-sample weights; ignored when ``exact`` is set.
        exact: use ``EXACT_ROWWISE_ADAGRAD``, a smaller table (100 rows) and
            indices sampled with replacement.
    """
    E = int(1e4) if not exact else int(1e2)
    D_gradcheck = D_gradcheck * 4
    # The weighted path is only exercised with the approximate optimizer.
    weighted = False if exact else weighted

    D = D * 4
    bs = [
        torch.nn.EmbeddingBag(E, D, mode="sum", sparse=True).cuda()
        for _ in range(T)
    ]
    if fp16:
        bs = [b.half() for b in bs]

    # Indices are sampled with replacement only when ``exact`` is set, so
    # repeated rows occur only on the exact-optimizer path.
    xs = [
        torch.from_numpy(
            np.random.choice(range(E), size=(B, L),
                             replace=bool(exact)).astype(np.int64)).cuda()
        for _ in range(T)
    ]
    xws = [torch.randn(size=(B, L)).cuda() for _ in range(T)]

    if fp16:
        xws = [xw.half() for xw in xws]

    def b_indices(b, x):
        """Run bag ``b`` on dense indices ``x`` via the (indices, offsets) form."""
        (indices, offsets) = get_offsets_from_dense(x)
        return b(indices.long(), offsets.to(torch.int64))

    fs = ([b_indices(b, x) for (b, x) in zip(bs, xs)] if not weighted else
          [b(x, per_sample_weights=xw) for (b, x, xw) in zip(bs, xs, xws)])
    gos = [torch.randn_like(f) for f in fs]
    for (f, go) in zip(fs, gos):
        f.backward(go)

    # Adagrad hyper-parameters shared by the batched module and the reference.
    lr = 0.5
    eps = 0.2
    tol = 1.0e-3 if fp16 else 1.0e-4

    cc = table_batched_embeddings_ops.TableBatchedEmbeddingBags(
        T,
        E,
        D,
        optimizer=table_batched_embeddings_ops.Optimizer.APPROX_ROWWISE_ADAGRAD
        if not exact else
        table_batched_embeddings_ops.Optimizer.EXACT_ROWWISE_ADAGRAD,
        learning_rate=lr,
        eps=eps,
        fp16=fp16,
        stochastic_rounding=stochastic_rounding,
    ).cuda()

    # Start the batched module from the reference weights.
    for t in range(T):
        cc.embedding_weights.data.view(T, E, D)[t, :, :] = bs[t].weight

    x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
    xw = torch.cat([xw.view(1, B, L) for xw in xws], dim=0)

    (indices, offsets) = get_table_batched_offsets_from_dense(x)
    fc2 = (cc(indices, offsets) if not weighted else cc(
        indices, offsets,
        xw.contiguous().view(-1).cuda()))
    fc2.backward(torch.cat([go.view(B, 1, D) for go in gos], dim=1))

    # optimizer state is sum_square_grads.
    for t in range(T):
        torch.testing.assert_allclose(
            cc.optimizer_state.view(T, E)[t],
            bs[t].weight.grad.float().to_dense().pow(2).sum(dim=1),
            atol=tol,
            rtol=tol,
        )

    for t in range(T):
        # Out-of-place sqrt/add: building the expected value must not modify
        # the optimizer state of the module under test.
        denominator = cc.optimizer_state.view(T, E)[t, :].sqrt().add(eps)
        torch.testing.assert_allclose(
            cc.embedding_weights.view(T, E, D)[t, :, :].float(),
            torch.addcdiv(
                bs[t].weight.float(),
                value=-lr,
                tensor1=bs[t].weight.grad.float().to_dense(),
                tensor2=denominator.view(E, 1),
            ),
            atol=tol,
            rtol=tol,
        )

    if weighted:
        # Learning rate 0 and double precision for the numerical gradient
        # check on the per-sample weights.
        cc = (table_batched_embeddings_ops.TableBatchedEmbeddingBags(
            T,
            E,
            D_gradcheck,
            optimizer=table_batched_embeddings_ops.Optimizer.
            APPROX_ROWWISE_ADAGRAD,
            learning_rate=0.0,
            eps=eps,
            fp16=fp16,
            stochastic_rounding=stochastic_rounding,
        ).cuda().double())
        per_sample_weights = xw.contiguous().view(-1).cuda().double()
        per_sample_weights.requires_grad = True
        indices.requires_grad = False
        offsets.requires_grad = False
        cc.embedding_weights.requires_grad = False
        torch.autograd.gradcheck(cc, (indices, offsets, per_sample_weights))