def test_cache_pipeline(self, T, D, B, log_E, L, mixed, cache_algorithm):
        iters = 3
        E = int(10**log_E)
        D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(
                    np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E))
                for _ in range(T)
            ]
        managed = [
            split_table_batched_embeddings_ops.EmbeddingLocation.
            MANAGED_CACHING
        ] * T
        if mixed:
            average_D = sum(Ds) // T
            for t, d in enumerate(Ds):
                managed[t] = (
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                    if d < average_D else managed[t])
        cc_ref = (
            split_table_batched_embeddings_ops.
            SplitTableBatchedEmbeddingBagsCodegen([(
                E,
                D,
                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            ) for (E, D) in zip(Es, Ds)], ))
        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D, M,
              split_table_batched_embeddings_ops.ComputeDevice.CUDA)
             for (E, D, M) in zip(Es, Ds, managed)],
            cache_algorithm=cache_algorithm,
        )
        for t in range(T):
            assert (cc.split_embedding_weights()[t].size() ==
                    cc_ref.split_embedding_weights()[t].size())
            cc.split_embedding_weights()[t].data.copy_(
                cc_ref.split_embedding_weights()[t])

        requests = generate_requests(iters, B, T, L, min(Es), reuse=0.1)
        grad_output = torch.randn(B, sum(Ds)).cuda()

        for indices, offsets, _ in requests:
            output = cc(indices, offsets)
            output_ref = cc_ref(indices, offsets)
            torch.testing.assert_allclose(output, output_ref)
            output.backward(grad_output)
            output_ref.backward(grad_output)
        cc.flush()
        for t in range(T):
            torch.testing.assert_allclose(cc.split_embedding_weights()[t],
                                          cc_ref.split_embedding_weights()[t])
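
The assume() calls and free parameters (T, D, B, log_E, L, ...) in these tests suggest they are property-based Hypothesis tests whose @given decorators are not shown. A minimal sketch of how the test above might be parametrized; the class name, strategy bounds, and settings are assumptions, not the original decorator:

import unittest

import hypothesis.strategies as st
from hypothesis import Verbosity, given, settings


class SplitTableBatchedEmbeddingsTest(unittest.TestCase):  # class name assumed
    @given(
        T=st.integers(min_value=1, max_value=5),
        D=st.integers(min_value=2, max_value=64),
        B=st.integers(min_value=1, max_value=64),
        log_E=st.integers(min_value=3, max_value=5),
        L=st.integers(min_value=1, max_value=20),
        mixed=st.booleans(),
        cache_algorithm=st.sampled_from(
            split_table_batched_embeddings_ops.CacheAlgorithm
        ),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
    def test_cache_pipeline(self, T, D, B, log_E, L, mixed, cache_algorithm):
        ...  # body as shown in the example above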
Example #2
    def test_backward_adagrad(  # noqa C901
        self,
        T,
        D,
        B,
        log_E,
        L,
        D_gradcheck,
        fp16,
        stochastic_rounding,
        weighted,
        row_wise,
        exact,
        mixed,
        use_cache,
        cache_algorithm,
        pooling_mode,
        use_cpu,
    ):
        # NOTE: cache is not applicable to CPU version.
        assume(not use_cpu or not use_cache)

        # NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version
        #       so we have to limit (T * B * L * D)!
        assume(not use_cpu or T * B * L * D <= 1024)

        assume(
            pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            or not weighted
        )
        mode = (
            "sum"
            if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            else "mean"
        )

        # stochastic rounding only implemented for rowwise
        assume(not stochastic_rounding or row_wise)
        # exact only implemented for rowwise non-weighted
        assume(not exact or (row_wise and not weighted))
        # need unique indices for non-exact tests
        assume(exact or int(10 ** log_E) > int(2.1 * B * L))
        # only row-wise supports caching
        assume(row_wise or not use_cache)

        E = int(10 ** log_E)
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
            ]
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
        if use_cpu:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            ] * T
            compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
        elif use_cache:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
            ] * T
            if mixed:
                average_D = sum(Ds) // T
                for t, d in enumerate(Ds):
                    managed[t] = (
                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                        if d < average_D
                        else managed[t]
                    )
        else:
            managed = [
                np.random.choice(
                    [
                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                        split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                    ]
                )
                for _ in range(T)
            ]
        bs = [
            to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
            for (E, D) in zip(Es, Ds)
        ]

        if fp16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            bs = [b.half() for b in bs]

        feature_table_map = list(range(T))
        if exact:
            # autograd with shared embedding only works for exact
            table_to_replicate = T // 2
            bs.insert(table_to_replicate, bs[table_to_replicate])
            feature_table_map.insert(table_to_replicate, table_to_replicate)

        xs = [
            to_device(torch.from_numpy(
                np.random.choice(range(Es[t]), size=(B, L), replace=exact).astype(
                    np.int64
                )
            ), use_cpu)
            for t in feature_table_map
        ]
        xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(len(xs))]
        xws_acc_type = copy.deepcopy(xws)

        if fp16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            xws = [xw.half() for xw in xws]

        fs = (
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
        )
        gos = [torch.randn_like(f) for f in fs]
        [f.backward(go) for (f, go) in zip(fs, gos)]
        # do SGD update
        lr = 0.5
        eps = 0.2

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
            feature_table_map=feature_table_map,
            optimizer=OptimType.EXACT_ROWWISE_ADAGRAD
            if row_wise
            else OptimType.EXACT_ADAGRAD,
            learning_rate=lr,
            eps=eps,
            fp16=fp16,
            stochastic_rounding=stochastic_rounding,
            pooling_mode=pooling_mode,
        )

        if exact:
            del bs[table_to_replicate]
        for t in range(T):
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
        xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
        fc2 = (
            cc(indices, offsets)
            if not weighted
            else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
        )
        fc2.backward(torch.cat([go.view(B, -1) for go in gos], dim=1))
        cc.flush()

        split_optimizer_states = [s for (s,) in cc.split_optimizer_states()]
        for t in range(T):
            ref_optimizer_state = bs[t].weight.grad.float().to_dense().pow(2)
            torch.testing.assert_allclose(
                split_optimizer_states[t].float(),
                ref_optimizer_state.mean(dim=1) if row_wise else ref_optimizer_state,
                atol=5.0e-3 if fp16 else 1.0e-4,
                rtol=5.0e-3 if fp16 else 1.0e-4,
            )

        for t in range(T):
            # optimizer_state = squares (no row-wise) or sum squares (row-wise)
            torch.testing.assert_allclose(
                cc.split_embedding_weights()[t].float(),
                torch.addcdiv(
                    bs[t].weight.float(),
                    value=-lr,
                    tensor1=bs[t].weight.grad.float().to_dense(),
                    tensor2=split_optimizer_states[t]
                    .float()
                    .sqrt_()
                    .add_(eps)
                    .view(Es[t], 1 if row_wise else Ds[t]),
                ),
                atol=5.0e-3 if fp16 else 1.0e-4,
                rtol=5.0e-3 if fp16 else 1.0e-4,
            )
        if use_cpu:
            D_gradcheck = (D_gradcheck + 15) // 16 * 4
        else:
            D_gradcheck = D_gradcheck * 4
        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D_gradcheck, M, compute_device) for (E, M) in zip(Es, managed)],
            feature_table_map=feature_table_map,
            optimizer=OptimType.EXACT_ROWWISE_ADAGRAD
            if row_wise
            else OptimType.EXACT_ADAGRAD,
            learning_rate=0.0,
            eps=eps,
            fp16=fp16,
            stochastic_rounding=stochastic_rounding,
            # NOTE: only SUM pooling can work with per_sample_weights!
            pooling_mode=split_table_batched_embeddings_ops.PoolingMode.SUM,
        )
        if use_cpu:
            # NOTE: GPU version of SplitTableBatchedEmbeddingBagsCodegen doesn't support double.
            cc = cc.double()

        per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
        if use_cpu:
            per_sample_weights = per_sample_weights.double()
        per_sample_weights.requires_grad = True
        indices.requires_grad = False
        offsets.requires_grad = False
        for param in cc.parameters():
            param.requires_grad = False
        torch.autograd.gradcheck(cc, (indices, offsets, per_sample_weights))

        per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
        if use_cpu:
            per_sample_weights = per_sample_weights.double()
        per_sample_weights.requires_grad = True
        indices.requires_grad = False
        offsets.requires_grad = False
        for param in cc.parameters():
            param.requires_grad = False
        y = cc(indices, offsets, per_sample_weights)
        y.sum().backward()
        indice_weight_grad_all = per_sample_weights.grad.clone().cpu()
        T_ = len(xws)
        feature_requires_grad = to_device(
            torch.tensor(np.random.choice([0, 1], replace=True, size=(T_,))).int(), use_cpu
        )
        per_sample_weights = per_sample_weights.detach().clone()
        per_sample_weights.requires_grad = True
        y = cc(
            indices,
            offsets,
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        )
        y.sum().backward()
        indice_weight_grad_mask = per_sample_weights.grad.clone().cpu()
        for t in range(T_):
            if feature_requires_grad[t]:
                torch.testing.assert_allclose(
                    indice_weight_grad_mask.view(T_, B, L)[t],
                    indice_weight_grad_all.view(T_, B, L)[t],
                )
            else:
                torch.testing.assert_allclose(
                    indice_weight_grad_mask.view(T_, B, L)[t],
                    torch.zeros_like(indice_weight_grad_mask.view(T_, B, L)[t]),
                )
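
These tests convert the dense (T, B, L) index tensor into the flat indices/offsets pair via get_table_batched_offsets_from_dense(x, use_cpu). A minimal sketch of that conversion, assuming every bag holds exactly L indices; the helper itself is not shown in these examples, so treat the details as illustrative:

import torch


def get_table_batched_offsets_from_dense_sketch(merged_indices, use_cpu=False):
    # Sketch only: bags are fixed-length here, so offsets are multiples of L.
    (T, B, L) = merged_indices.size()
    offsets = torch.arange(0, T * B * L + 1, L, dtype=torch.int64)
    indices = merged_indices.contiguous().view(-1)
    if not use_cpu:
        indices, offsets = indices.cuda(), offsets.cuda()
    return (indices, offsets)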
Example #3
    def test_backward_sgd(  # noqa C901
        self,
        T,
        D,
        B,
        log_E,
        L,
        fp16,
        weighted,
        exact,
        mixed,
        use_cache,
        cache_algorithm,
        long_segments,
        pooling_mode,
        use_cpu,
    ):
        # NOTE: cache is not applicable to CPU version.
        assume(not use_cpu or not use_cache)
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)

        assume(
            pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            or not weighted
        )
        mode = (
            "sum"
            if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            else "mean"
        )

        # only non-exact supports caching
        assume(not exact or not use_cache)
        E = int(10 ** log_E)
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
            ]
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
        if use_cpu:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            ] * T
            compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
        elif use_cache:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
            ] * T
            if mixed:
                average_D = sum(Ds) // T
                for t, d in enumerate(Ds):
                    managed[t] = (
                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                        if d < average_D
                        else managed[t]
                    )
        else:
            managed = [
                np.random.choice(
                    [
                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                        split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                    ]
                )
                for _ in range(T)
            ]
        bs = [
            to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
            for (E, D) in zip(Es, Ds)
        ]

        if fp16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            bs = [b.half() for b in bs]

        feature_table_map = list(range(T))
        table_to_replicate = T // 2
        bs.insert(table_to_replicate, bs[table_to_replicate])
        feature_table_map.insert(table_to_replicate, table_to_replicate)

        xs = [
            to_device(torch.from_numpy(
                np.random.choice(range(Es[t]), size=(B, L), replace=True).astype(
                    np.int64
                )
            ), use_cpu)
            for t in feature_table_map
        ]

        if long_segments and L > 0:
            for x in xs:
                x[:, 0] = 0

        xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(len(xs))]
        xws_acc_type = copy.deepcopy(xws)

        if fp16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            xws = [xw.half() for xw in xws]

        fs = (
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
        )
        gos = [torch.randn_like(f) for f in fs]
        [f.backward(go) for (f, go) in zip(fs, gos)]
        # do SGD update
        lr = 0.05
        del bs[table_to_replicate]
        new_weights = [(b.weight - b.weight.grad * lr) for b in bs]

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
            optimizer=OptimType.EXACT_SGD,
            feature_table_map=feature_table_map,
            learning_rate=0.05,
            fp16=fp16,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling_mode,
        )

        for t in range(T):
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
        xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
        fc2 = (
            cc(indices, offsets)
            if not weighted
            else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
        )
        goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
        fc2.backward(goc)
        if use_cache:
            cc.flush()
        for t in range(T):
            torch.testing.assert_allclose(
                cc.split_embedding_weights()[t],
                new_weights[t].half() if fp16 and use_cpu else new_weights[t],
                atol=(1.0e-2 if long_segments else 5.0e-3) if fp16 else 1.0e-5,
                rtol=(1.0e-2 if long_segments else 5.0e-3) if fp16 else 1.0e-5,
            )
Example #4
    def test_forward(
        self,
        T,
        D,
        B,
        log_E,
        L,
        fp16,
        weighted,
        mixed,
        use_cache,
        cache_algorithm,
        pooling_mode,
        use_cpu,
    ):
        # NOTE: cache is not applicable to CPU version.
        assume(not use_cpu or not use_cache)
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)

        assume(
            pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            or not weighted
        )
        mode = (
            "sum"
            if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            else "mean"
        )

        E = int(10 ** log_E)
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [int(1e4)] * T
        else:
            Ds = [
                div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
            ]
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
        if use_cpu:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            ] * T
            compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
        elif use_cache:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
            ] * T
            if mixed:
                average_D = sum(Ds) // T
                for t, d in enumerate(Ds):
                    managed[t] = (
                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                        if d < average_D
                        else managed[t]
                    )
        else:
            managed = [
                np.random.choice(
                    [
                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                        split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                    ]
                )
                for _ in range(T)
            ]
        bs = [
            to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
            for (E, D) in zip(Es, Ds)
        ]
        if fp16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            bs = [b.half() for b in bs]

        xs = [to_device(torch.randint(low=0, high=e, size=(B, L)), use_cpu) for e in Es]
        xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)]
        xws_acc_type = copy.deepcopy(xws)

        if fp16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            xws = [xw.half() for xw in xws]

        fs = (
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
        )
        f = torch.cat([f.view(B, -1) for f in fs], dim=1)

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            embedding_specs=[
                (
                    E,
                    D,
                    split_table_batched_embeddings_ops.EmbeddingLocation(M),
                    compute_device,
                )
                for (E, D, M) in zip(Es, Ds, managed)
            ],
            fp16=fp16,
            optimizer=OptimType.EXACT_SGD,
            learning_rate=0.05,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling_mode,
        )
        # NOTE: test TorchScript-compatible!
        cc = torch.jit.script(cc)

        for t in range(T):
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
        xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
        fc2 = (
            cc(indices, offsets)
            if not weighted
            else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
        )
        torch.testing.assert_allclose(
            fc2.float(),
            f.float(),
            atol=8.0e-3 if fp16 else 1.0e-5,
            rtol=8.0e-3 if fp16 else 1.0e-5,
        )
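
b_indices(...) is the small reference helper assumed throughout these examples; it evaluates a plain torch.nn.EmbeddingBag on a (B, L) index matrix by flattening it and passing explicit per-bag offsets. An illustrative sketch whose name and signature merely mirror the call sites above, not a confirmed implementation:

import torch


def b_indices_sketch(b, x, per_sample_weights=None, use_cpu=False):
    # use_cpu mirrors the call sites; the device simply follows x here.
    (B, L) = x.size()
    offsets = torch.arange(0, B * L + 1, L, dtype=torch.int64, device=x.device)
    return b(
        x.contiguous().view(-1),
        offsets,
        per_sample_weights=per_sample_weights,
    )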
Example #5
    def test_backward_optimizers(  # noqa C901
        self,
        T,
        D,
        B,
        log_E,
        L,
        stochastic_rounding,
        weighted,
        mixed,
        optimizer,
        long_segments,
        pooling_mode,
        use_cpu,
    ):
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)

        assume(
            pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            or not weighted
        )
        mode = (
            "sum"
            if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            else "mean"
        )

        E = int(10 ** log_E)
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
            ]
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
        if use_cpu:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            ] * T
            compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
        else:
            managed = [
                np.random.choice(
                    [
                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                        split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                    ]
                )
                for _ in range(T)
            ]
        bs = [
            to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
            for (E, D) in zip(Es, Ds)
        ]

        xs = [
            to_device(torch.from_numpy(
                np.random.choice(range(e), size=(B, L), replace=True).astype(np.int64)
            ), use_cpu)
            for e in Es
        ]
        if long_segments and L > 0:
            for x in xs:
                x[:, 0] = 0

        xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)]
        xws_acc_type = copy.deepcopy(xws)

        fs = (
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
        )
        gos = [torch.randn_like(f) for f in fs]
        [f.backward(go) for (f, go) in zip(fs, gos)]
        # do SGD update

        optimizer_kwargs = {"learning_rate": 0.5}
        (lr, eps, beta1, beta2, weight_decay, momentum, eta) = (
            0.5,
            1e-4,
            0.9,
            0.99,
            0.01,
            0.9,
            0.01,
        )
        if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD, OptimType.EXACT_ADAGRAD):
            optimizer_kwargs["eps"] = eps

        if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
            optimizer_kwargs["eps"] = eps
            optimizer_kwargs["beta1"] = beta1
            optimizer_kwargs["beta2"] = beta2
            optimizer_kwargs["weight_decay"] = weight_decay

        if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
            optimizer_kwargs["eps"] = eps
            optimizer_kwargs["beta1"] = beta1
            optimizer_kwargs["beta2"] = beta2
            optimizer_kwargs["weight_decay"] = weight_decay

        if optimizer == OptimType.LARS_SGD:
            optimizer_kwargs["weight_decay"] = weight_decay
            optimizer_kwargs["momentum"] = momentum
            optimizer_kwargs["eta"] = eta

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
            optimizer=optimizer,
            stochastic_rounding=stochastic_rounding,
            pooling_mode=pooling_mode,
            **optimizer_kwargs,
        )

        for t in range(T):
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
        xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
        fc2 = (
            cc(indices, offsets)
            if not weighted
            else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
        )
        fc2.backward(torch.cat([go.view(B, -1) for go in gos], dim=1))
        cc.flush()

        split_optimizer_states = cc.split_optimizer_states()
        assert len(split_optimizer_states) == T
        split_weights = cc.split_embedding_weights()

        if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD, OptimType.EXACT_ADAGRAD):
            rowwise = optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
            for t in range(T):
                (m1,) = split_optimizer_states[t]
                m1_ref = (
                    bs[t].weight.grad.to_dense().pow(2)
                    if not rowwise
                    else bs[t].weight.grad.to_dense().pow(2).mean(dim=1)
                )
                torch.testing.assert_allclose(
                    m1.float(), m1_ref.float(), atol=1.0e-4, rtol=1.0e-4
                )
                weights_new = split_weights[t]
                weights_ref = bs[t].weight - lr * bs[t].weight.grad.to_dense() / (
                    torch.sqrt(
                        m1_ref if not rowwise else m1_ref.view(m1_ref.numel(), 1)
                    )
                    + eps
                )
                # TODO: why is tolerance off here?
                torch.testing.assert_allclose(
                    weights_new.float(), weights_ref.float(), atol=1.0e-2, rtol=1.0e-2
                )

        if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
            rowwise = optimizer == OptimType.PARTIAL_ROWWISE_ADAM
            for t in range(T):
                (m1, m2) = split_optimizer_states[t]
                m2_ref = (
                    bs[t].weight.grad.to_dense().pow(2)
                    if not rowwise
                    else bs[t].weight.grad.to_dense().pow(2).mean(dim=1)
                ) * (1.0 - beta2)
                torch.testing.assert_allclose(m2, m2_ref, atol=1.0e-4, rtol=1.0e-4)
                m1_ref = bs[t].weight.grad.to_dense() * (1.0 - beta1)
                torch.testing.assert_allclose(m1, m1_ref, atol=1.0e-4, rtol=1.0e-4)
                iter_ = cc.iter.item()
                v_hat_t = m2_ref / (1 - beta2 ** iter_)
                v_hat_t = v_hat_t if not rowwise else v_hat_t.view(v_hat_t.numel(), 1)
                m_hat_t = m1_ref / (1 - beta1 ** iter_)
                weights_new = split_weights[t]
                weights_ref = (
                    torch.addcdiv(
                        bs[t].weight,
                        value=-lr,
                        tensor1=m_hat_t,
                        tensor2=v_hat_t.sqrt_().add_(eps),
                    )
                    - lr * weight_decay * bs[t].weight
                )
                torch.testing.assert_allclose(
                    weights_new.index_select(dim=0, index=x[t].view(-1)),
                    weights_ref.index_select(dim=0, index=x[t].view(-1)),
                    atol=1.0e-3,
                    rtol=1.0e-3,
                )

        if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
            rowwise = optimizer == OptimType.PARTIAL_ROWWISE_LAMB
            for t in range(T):
                (m1, m2) = split_optimizer_states[t]
                m2_ref = (
                    bs[t].weight.grad.to_dense().pow(2)
                    if not rowwise
                    else bs[t].weight.grad.to_dense().pow(2).mean(dim=1)
                ) * (1.0 - beta2)
                torch.testing.assert_allclose(m2, m2_ref, atol=1.0e-4, rtol=1.0e-4)
                m1_ref = bs[t].weight.grad.to_dense() * (1.0 - beta1)
                torch.testing.assert_allclose(m1, m1_ref, atol=1.0e-4, rtol=1.0e-4)
                iter_ = cc.iter.item()
                v_hat_t = m2_ref / (1 - beta2 ** iter_)
                v_hat_t = v_hat_t if not rowwise else v_hat_t.view(v_hat_t.numel(), 1)
                m_hat_t = m1_ref / (1 - beta1 ** iter_)
                rtw = (m_hat_t / (torch.sqrt(v_hat_t) + eps)) + weight_decay * bs[
                    t
                ].weight
                true_ratio = torch.linalg.norm(bs[t].weight, dim=1, ord=2).view(
                    m1.shape[0], 1
                ) / torch.linalg.norm(rtw, dim=1, ord=2).view(m1.shape[0], 1)
                weights_new = split_weights[t]
                weights_ref = bs[t].weight - lr * true_ratio * rtw
                torch.testing.assert_allclose(
                    weights_new.index_select(dim=0, index=x[t].view(-1)),
                    weights_ref.index_select(dim=0, index=x[t].view(-1)),
                    atol=1.0e-3,
                    rtol=1.0e-3,
                )

        if optimizer == OptimType.LARS_SGD:
            for t in range(T):
                (m1,) = split_optimizer_states[t]
                weight_norm = torch.linalg.norm(bs[t].weight, dim=1, ord=2).view(
                    m1.shape[0], 1
                )
                grad_norm = torch.linalg.norm(
                    bs[t].weight.grad.to_dense(), dim=1, ord=2
                ).view(m1.shape[0], 1)
                adjusted_lr = (
                    lr * eta * weight_norm / (grad_norm + weight_decay * weight_norm)
                )
                m1_ref = adjusted_lr * (
                    bs[t].weight.grad.to_dense() + weight_decay * bs[t].weight
                )

                torch.testing.assert_allclose(
                    m1.index_select(dim=0, index=x[t].view(-1)),
                    m1_ref.index_select(dim=0, index=x[t].view(-1)),
                    atol=1.0e-4,
                    rtol=1.0e-4,
                )
                weights_new = split_weights[t]
                weights_ref = bs[t].weight - m1_ref
                torch.testing.assert_allclose(
                    weights_new.index_select(dim=0, index=x[t].view(-1)),
                    weights_ref.index_select(dim=0, index=x[t].view(-1)),
                    atol=1.0e-4,
                    rtol=1.0e-4,
                )
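The benchmark functions below rely on generate_requests and benchmark_requests, which are not shown here. Each request is an (indices, offsets, per_sample_weights) tuple, and benchmark_requests returns the average wall-clock time per request. A minimal sketch of that timing helper, under the assumption that per-request CUDA-event timing and warm-up handling in the real helper can be ignored:

import time

import torch


def benchmark_requests_sketch(requests, func):
    # Time the whole stream of requests and report the mean time per iteration.
    torch.cuda.synchronize()
    start = time.time()
    for (indices, offsets, weights) in requests:
        func(indices, offsets, weights)
    torch.cuda.synchronize()
    return (time.time() - start) / len(requests)
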
def cache(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    cache_algorithm: str,
    cache_sets: int,
    embedding_dim: int,
    fp16: bool,
    iters: int,
    long_index: bool,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    weighted: bool,
) -> None:
    np.random.seed(42)

    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    cache_alg = (
        split_table_batched_embeddings_ops.CacheAlgorithm.LRU
        if cache_algorithm == "lru"
        else split_table_batched_embeddings_ops.CacheAlgorithm.LFU
    )
    if mixed:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T

    emb_nc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds
        ],
        optimizer=optimizer,
        fp16=fp16,
    ).cuda()
    emb = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds
        ],
        optimizer=optimizer,
        fp16=fp16,
        cache_sets=cache_sets,
        cache_algorithm=cache_alg,
    ).cuda()
    nparams = sum(w.numel() for w in emb.split_embedding_weights())
    logging.info(
        f"Embedding tables: {E * T} rows, {nparams / 1.0e9: .2f} GParam, "
        f"{nparams * (2 if fp16 else 4)  / 1.0e6: .2f}MB"
    )
    logging.info(
        f"Accessed weights per batch: {B * T * L} rows, "
        f"{B * T * L * D * (2 if fp16 else 4) / 1.0e6: .2f}MB"
    )

    requests = generate_requests(
        2 * iters, B, T, L, E, reuse=reuse, alpha=alpha, weighted=weighted
    )
    warmup_requests, requests = requests[:iters], requests[iters:]
    grad_output = torch.randn(B, sum(Ds)).cuda()

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb_nc(
            indices.long(), offsets.long(), per_sample_weights
        ).backward(grad_output),
    )
    logging.info(
        f"ForwardBackward (UVM), B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * (2 if fp16 else 4) * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f}GB/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    # warm up
    for indices, offsets, _ in warmup_requests:
        emb.prefetch(indices.long(), offsets.long())
    # get cache miss rate (forward and backward) and exchanged cache lines (prefetch)
    cache_misses = []
    exchanged_cache_lines = []
    NOT_FOUND = np.iinfo(np.int32).max
    for indices, offsets, _ in requests:
        # pyre-fixme[16]
        old_lxu_cache_state = emb.lxu_cache_state.clone()
        emb.forward(indices.long(), offsets.long())
        exchanged_cache_lines.append(
            (emb.lxu_cache_state != old_lxu_cache_state).sum().item()
        )
        cache_misses.append((emb.lxu_cache_locations == NOT_FOUND).sum().item())
    logging.info(
        f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines)/len(requests): .2f}, "
        f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}"
    )
    logging.info(
        f"Cache miss -- mean: {sum(cache_misses)/len(requests)}, "
        f"max: {max(cache_misses)}, min: {min(cache_misses)}"
    )

    # benchmark prefetch
    emb.reset_cache_states()
    for indices, offsets, _ in warmup_requests:
        emb.prefetch(indices, offsets)
    prefetch_time, forward_backward_time = benchmark_pipelined_requests(
        requests,
        lambda indices, offsets, indices_weights: emb.prefetch(indices, offsets),
        lambda indices, offsets, indices_weights: emb.forward(
            indices, offsets, indices_weights, prefetch=False
        ).backward(grad_output),
    )
    e2e_time = prefetch_time + forward_backward_time

    logging.info(
        f"ForwardBackward (LXU), reuse: {reuse}, alpha: {alpha}, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * (2 if fp16 else 4) * B * sum(Ds) * L / e2e_time / 1.0e9: .2f}GB/s, "
        f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, "
        f"{2 * sum(exchanged_cache_lines) * (2 if fp16 else 4) * D / prefetch_time / len(requests) / 1.0e9: .2f} GB/s, "
        f"Tfwdbwd: {forward_backward_time * 1.0e6:.0f}us, "
        f"{3 * (2 if fp16 else 4) * B * sum(Ds) * L / forward_backward_time / 1.0e9: .2f} GB/s, "
        f"Te2e: {e2e_time * 1.0e6:.0f}us, "
    )
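A hypothetical invocation of the cache() benchmark above; in the benchmark script these arguments would normally come from command-line options, and the values below are placeholders rather than tuned settings:

cache(
    alpha=1.0,
    bag_size=20,
    batch_size=512,
    cache_algorithm="lru",
    cache_sets=1048576,
    embedding_dim=128,
    fp16=False,
    iters=100,
    long_index=False,
    mixed=False,
    num_embeddings=int(1e5),
    num_tables=32,
    reuse=0.1,
    weighted=False,
)
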
def uvm(
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    fp16: bool,
    iters: int,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    uvm_tables: int,
    uvm_bag_size: int,
    weighted: bool,
) -> None:

    np.random.seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    T_uvm = uvm_tables
    assert T_uvm <= T
    T_gpu = T - T_uvm
    L_uvm = uvm_bag_size

    if mixed:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T
    emb_uvm = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds[:T_uvm]
        ],
        fp16=fp16,
    ).cuda()
    emb_gpu = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds[T_uvm:]
        ],
        fp16=fp16,
    ).cuda()
    emb_mixed = (
        split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [
                (
                    E,
                    d,
                    managed_option,
                    split_table_batched_embeddings_ops.ComputeDevice.CUDA,
                )
                for (d, managed_option) in zip(
                    Ds,
                    [split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED]
                    * T_uvm
                    + [split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE]
                    * T_gpu,
                )
            ],
            fp16=fp16,
        ).cuda()
    )
    requests_uvm = generate_requests(
        iters,
        B,
        T_uvm,
        L_uvm,
        E,
        reuse=reuse,
        alpha=alpha,
        fp16=fp16,
        weighted=weighted,
    )
    requests_gpu = generate_requests(
        iters, B, T_gpu, L, E, reuse=reuse, alpha=alpha, fp16=fp16, weighted=False
    )
    requests = []
    for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu):
        indices = torch.cat([rs_uvm[0], rs_gpu[0]])
        lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B)
        offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda()
        # pyre-fixme[6]
        per_sample_weights = torch.cat([rs_uvm[2], rs_gpu[2]]) if weighted else None
        requests.append((indices, offsets, per_sample_weights))

    # forward
    time_per_iter = benchmark_requests(
        requests_gpu,
        lambda indices, offsets, per_sample_weights: emb_gpu.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    logging.info(
        f"GPU Forward, B: {B}, "
        f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {(2 if fp16 else 4) * B * T_gpu * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    time_per_iter = benchmark_requests(
        requests_uvm,
        lambda indices, offsets, per_sample_weights: emb_uvm.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    logging.info(
        f"UVM Forward, B: {B}, "
        f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {(2 if fp16 else 4) * B * T_gpu * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb_mixed.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    logging.info(
        f"Mixed Forward, B: {B}, "
        f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {(2 if fp16 else 4) * B * T_gpu * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )
def device(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    fp16: bool,
    iters: int,
    managed: str,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    row_wise: bool,
    weighted: bool,
    weighted_num_requires_grad: Optional[int],
) -> None:
    np.random.seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    if weighted_num_requires_grad:
        assert weighted_num_requires_grad <= T
        weighted_requires_grad_tables = np.random.choice(
            T, replace=False, size=(weighted_num_requires_grad,)
        ).tolist()
        feature_requires_grad = (
            torch.tensor(
                [1 if t in weighted_requires_grad_tables else 0 for t in range(T)]
            )
            .cuda()
            .int()
        )
    else:
        feature_requires_grad = None
    if mixed:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T
    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

    if managed == "device":
        managed_option = split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
    else:
        managed_option = split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED
    emb = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(E, d, managed_option, split_table_batched_embeddings_ops.ComputeDevice.CUDA) for d in Ds],
        optimizer=optimizer,
        learning_rate=0.1,
        eps=0.1,
        stochastic_rounding=False,
        fp16=fp16,
    ).cuda()

    nparams = sum(w.numel() for w in emb.split_embedding_weights())
    logging.info(
        f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
        f"{nparams * (2 if fp16 else 4)  / 1.0e9: .2f}GB"
    )
    logging.info(
        f"Accessed weights per batch: {B * T * L * D * (2 if fp16 else 4) / 1.0e6: .2f}MB"
    )

    requests = generate_requests(
        iters, B, T, L, E, reuse=reuse, alpha=alpha, fp16=fp16, weighted=weighted
    )

    # forward
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        ),
    )
    logging.info(
        f"Forward, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {(2 if fp16 else 4) * B * T * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    grad_output = torch.randn(B, sum(Ds)).cuda()
    # backward
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb(
            indices.long(),
            offsets.long(),
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        ).backward(grad_output),
    )
    logging.info(
        f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * (2 if fp16 else 4) * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f}GB/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )
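
A hypothetical invocation of the device() benchmark with placeholder values; again, the benchmark script would normally supply these from its command line:

device(
    alpha=1.0,
    bag_size=20,
    batch_size=512,
    embedding_dim=128,
    fp16=False,
    iters=100,
    managed="device",
    mixed=False,
    num_embeddings=int(1e5),
    num_tables=32,
    reuse=0.1,
    row_wise=True,
    weighted=False,
    weighted_num_requires_grad=None,
)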