Example #1
    def test_cache_pipeline(self, T, D, B, log_E, L, mixed, cache_algorithm):
        iters = 3
        E = int(10**log_E)
        D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(
                    np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E))
                for _ in range(T)
            ]
        managed = [
            split_table_batched_embeddings_ops.EmbeddingLocation.
            MANAGED_CACHING
        ] * T
        if mixed:
            average_D = sum(Ds) // T
            for t, d in enumerate(Ds):
                managed[t] = (
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                    if d < average_D else managed[t])
        cc_ref = (
            split_table_batched_embeddings_ops.
            SplitTableBatchedEmbeddingBagsCodegen([(
                E,
                D,
                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            ) for (E, D) in zip(Es, Ds)], ))
        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D, M, split_table_batched_embeddings_ops.ComputeDevice.CUDA)
             for (E, D, M) in zip(Es, Ds, managed)],
            cache_algorithm=cache_algorithm,
        )
        for t in range(T):
            assert (cc.split_embedding_weights()[t].size() ==
                    cc_ref.split_embedding_weights()[t].size())
            cc.split_embedding_weights()[t].data.copy_(
                cc_ref.split_embedding_weights()[t])

        requests = generate_requests(iters, B, T, L, min(Es), reuse=0.1)
        grad_output = torch.randn(B, sum(Ds)).cuda()

        for indices, offsets, _ in requests:
            output = cc(indices, offsets)
            output_ref = cc_ref(indices, offsets)
            torch.testing.assert_allclose(output, output_ref)
            output.backward(grad_output)
            output_ref.backward(grad_output)
        cc.flush()
        for t in range(T):
            torch.testing.assert_allclose(cc.split_embedding_weights()[t],
                                          cc_ref.split_embedding_weights()[t])
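
The test above copies weights from a reference table (all on device) into the cached table and then checks forward outputs and post-flush weights for parity. A minimal sketch of the same parity-check pattern, using plain torch.nn.EmbeddingBag in place of the FBGEMM module and illustrative sizes:

import torch

# Illustrative sizes: E rows, D dims, B bags of L indices each.
E, D, B, L = 100, 16, 8, 4
ref = torch.nn.EmbeddingBag(E, D, mode="sum")
test = torch.nn.EmbeddingBag(E, D, mode="sum")
# Start both modules from the same weights, as the test does with copy_.
test.weight.data.copy_(ref.weight.data)

indices = torch.randint(0, E, (B, L))
torch.testing.assert_allclose(test(indices), ref(indices))
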
Example #2
    def __init__(
        self,
        emb_dim,
        num_tables,
        num_rows,
        use_cpu,
    ) -> None:
        super().__init__()
        pooling_mode = split_table_batched_embeddings_ops.PoolingMode.SUM
        Ds = [emb_dim] * num_tables
        Es = [num_rows] * num_tables

        device = (
            split_table_batched_embeddings_ops.ComputeDevice.CPU
            if use_cpu
            else split_table_batched_embeddings_ops.ComputeDevice.CUDA
        )
        loc = (
            split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            if use_cpu
            else split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
        )

        self.emb_module = (
            split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
                embedding_specs=[
                    (
                        E,
                        D,
                        loc,
                        device,
                    )
                    for (E, D) in zip(Es, Ds)
                ],
                weights_precision=SparseType.FP32,
                optimizer=OptimType.EXACT_SGD,
                learning_rate=0.05,
                pooling_mode=pooling_mode,
            )
        )

        self.emb_module.init_embedding_weights_uniform(
            -EMB_WEIGHT_UNIFORM_INIT_BOUND, +EMB_WEIGHT_UNIFORM_INIT_BOUND
        )
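
This constructor wraps several identically shaped tables in one fused FBGEMM module with sum pooling and uniform weight initialization. A rough plain-PyTorch analogue, assuming nothing beyond what the constructor shows (the class name, the init bound value, and the sizes below are illustrative):

import torch

EMB_WEIGHT_UNIFORM_INIT_BOUND = 0.01  # assumed value; the real constant is defined elsewhere


class SimpleBatchedBags(torch.nn.Module):
    """One EmbeddingBag per table with sum pooling; the FBGEMM module fuses
    these tables into a single batched kernel, this sketch just loops."""

    def __init__(self, emb_dim, num_tables, num_rows):
        super().__init__()
        self.bags = torch.nn.ModuleList([
            torch.nn.EmbeddingBag(num_rows, emb_dim, mode="sum")
            for _ in range(num_tables)
        ])
        for bag in self.bags:
            torch.nn.init.uniform_(bag.weight,
                                   -EMB_WEIGHT_UNIFORM_INIT_BOUND,
                                   +EMB_WEIGHT_UNIFORM_INIT_BOUND)

    def forward(self, indices_per_table):
        # indices_per_table: one (B, L) index tensor per table.
        return torch.cat(
            [bag(x) for bag, x in zip(self.bags, indices_per_table)], dim=1)
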
Example #3
def cache(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    cache_algorithm: str,
    cache_sets: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    long_index: bool,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    weighted: bool,
) -> None:
    np.random.seed(42)

    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    cache_alg = (split_table_batched_embeddings_ops.CacheAlgorithm.LRU
                 if cache_algorithm == "lru" else
                 split_table_batched_embeddings_ops.CacheAlgorithm.LFU)
    if mixed:
        Ds = [
            div_round_up(
                np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T

    emb_nc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(
            E,
            d,
            split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
            split_table_batched_embeddings_ops.ComputeDevice.CUDA,
        ) for d in Ds],
        optimizer=optimizer,
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_nc.init_embedding_weights_uniform(-0.0003, 0.0003)

    emb = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(
            E,
            d,
            split_table_batched_embeddings_ops.EmbeddingLocation.
            MANAGED_CACHING,
            split_table_batched_embeddings_ops.ComputeDevice.CUDA,
        ) for d in Ds],
        optimizer=optimizer,
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
        cache_sets=cache_sets,
        cache_algorithm=cache_alg,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb.init_embedding_weights_uniform(-0.0003, 0.0003)

    nparams = sum(w.numel() for w in emb.split_embedding_weights())
    param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]
    logging.info(
        f"Embedding tables: {E * T} rows, {nparams / 1.0e9: .2f} GParam, "
        f"{nparams * param_size_multiplier  / 1.0e6: .2f}MB")
    logging.info(f"Accessed weights per batch: {B * T * L} rows, "
                 f"{B * T * L * D * param_size_multiplier / 1.0e6: .2f}MB")

    requests = generate_requests(2 * iters,
                                 B,
                                 T,
                                 L,
                                 E,
                                 reuse=reuse,
                                 alpha=alpha,
                                 weighted=weighted)
    warmup_requests, requests = requests[:iters], requests[iters:]
    grad_output = torch.randn(B, sum(Ds)).cuda()

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb_nc(
            indices.long(), offsets.long(), per_sample_weights).backward(
                grad_output),
    )
    logging.info(
        f"ForwardBackward (UVM), B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f}GB/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us")

    # warm up
    for indices, offsets, _ in warmup_requests:
        emb.forward(indices.long(), offsets.long())
    # get cache miss rate (forward and backward) and exchanged cache lines (prefetch)
    cache_misses = []
    exchanged_cache_lines = []
    NOT_FOUND = -1
    for indices, offsets, _ in requests:
        # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no attribute
        #  `lxu_cache_state`.
        old_lxu_cache_state = emb.lxu_cache_state.clone()
        emb.prefetch(indices.long(), offsets.long())
        exchanged_cache_lines.append(
            (emb.lxu_cache_state != old_lxu_cache_state).sum().item())
        cache_misses.append(
            (emb.lxu_cache_locations_list[0] == NOT_FOUND).sum().item())
        emb.forward(indices.long(), offsets.long())
    logging.info(
        f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines)/len(requests): .2f}, "
        f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}"
    )
    logging.info(f"Cache miss -- mean: {sum(cache_misses)/len(requests)}, "
                 f"max: {max(cache_misses)}, min: {min(cache_misses)}")

    # benchmark prefetch
    emb.reset_cache_states()
    for indices, offsets, _ in warmup_requests:
        emb.forward(indices, offsets)
    prefetch_time, forward_backward_time = benchmark_pipelined_requests(
        requests,
        lambda indices, offsets, indices_weights: emb.prefetch(
            indices, offsets),
        lambda indices, offsets, indices_weights: emb.forward(
            indices, offsets, indices_weights).backward(grad_output),
    )
    e2e_time = prefetch_time + forward_backward_time

    logging.info(
        f"ForwardBackward (LXU), reuse: {reuse}, alpha: {alpha}, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / e2e_time / 1.0e9: .2f}GB/s, "
        f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, "
        f"{2 * sum(exchanged_cache_lines) * param_size_multiplier * D / prefetch_time / len(requests) / 1.0e9: .2f} GB/s, "
        f"Tfwdbwd: {forward_backward_time * 1.0e6:.0f}us, "
        f"{3 * param_size_multiplier * B * sum(Ds) * L / forward_backward_time / 1.0e9: .2f} GB/s, "
        f"Te2e: {e2e_time * 1.0e6:.0f}us, ")
Example #4
def uvm(
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    uvm_tables: int,
    uvm_bag_size: int,
    weighted: bool,
) -> None:

    np.random.seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    T_uvm = uvm_tables
    assert T_uvm <= T
    T_gpu = T - T_uvm
    L_uvm = uvm_bag_size

    if mixed:
        Ds = [
            div_round_up(
                np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T
    emb_uvm = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(
            E,
            d,
            split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
            split_table_batched_embeddings_ops.ComputeDevice.CUDA,
        ) for d in Ds[:T_uvm]],
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_uvm.init_embedding_weights_uniform(-0.0003, 0.0003)

    emb_gpu = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(
            E,
            d,
            split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
            split_table_batched_embeddings_ops.ComputeDevice.CUDA,
        ) for d in Ds[T_uvm:]],
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_gpu.init_embedding_weights_uniform(-0.0003, 0.0003)

    emb_mixed = (
        split_table_batched_embeddings_ops.
        SplitTableBatchedEmbeddingBagsCodegen(
            [(
                E,
                d,
                managed_option,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            ) for (d, managed_option) in zip(
                Ds,
                [split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED
                 ] * T_uvm +
                [split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE] *
                T_gpu,
            )],
            weights_precision=weights_precision,
            stochastic_rounding=stoc,
        ).cuda())

    if weights_precision == SparseType.INT8:
        emb_mixed.init_embedding_weights_uniform(-0.0003, 0.0003)

    requests_uvm = generate_requests(
        iters,
        B,
        T_uvm,
        L_uvm,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=weighted,
    )
    requests_gpu = generate_requests(
        iters,
        B,
        T_gpu,
        L,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=False,
    )
    requests = []
    for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu):
        indices = torch.cat([rs_uvm[0], rs_gpu[0]])
        lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B)
        offsets = torch.tensor(
            ([0] + np.cumsum(lengths).tolist())).int().cuda()
        per_sample_weights = None
        if weighted:
            assert (this_rs_uvm_weights := rs_uvm[2]) is not None
            assert (this_rs_gpu_weights := rs_gpu[2]) is not None
            per_sample_weights = torch.cat(
                [this_rs_uvm_weights, this_rs_gpu_weights])
        requests.append((indices, offsets, per_sample_weights))

    # forward
    time_per_iter = benchmark_requests(
        requests_gpu,
        lambda indices, offsets, per_sample_weights: emb_gpu.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]

    logging.info(
        f"GPU Forward, B: {B}, "
        f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * T_gpu * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us")

    time_per_iter = benchmark_requests(
        requests_uvm,
        lambda indices, offsets, per_sample_weights: emb_uvm.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    logging.info(
        f"UVM Forward, B: {B}, "
        f"E: {E}, T: {T_uvm}, D: {D}, L: {L_uvm}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * T_uvm * L_uvm * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us")

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb_mixed.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    logging.info(
        f"Mixed Forward, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L(UVM): {L_uvm}, L(GPU): {L}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * (T_uvm * L_uvm + T_gpu * L) * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us")
Example #5
def device(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    managed: str,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    row_wise: bool,
    weighted: bool,
    weighted_num_requires_grad: Optional[int],
) -> None:
    np.random.seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    if weighted_num_requires_grad:
        assert weighted_num_requires_grad <= T
        weighted_requires_grad_tables = np.random.choice(
            T, replace=False, size=(weighted_num_requires_grad, )).tolist()
        feature_requires_grad = (torch.tensor([
            1 if t in weighted_requires_grad_tables else 0 for t in range(T)
        ]).cuda().int())
    else:
        feature_requires_grad = None
    if mixed:
        Ds = [
            div_round_up(
                np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T
    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

    if managed == "device":
        managed_option = split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
    else:
        managed_option = split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED

    emb = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(
            E,
            d,
            managed_option,
            split_table_batched_embeddings_ops.ComputeDevice.CUDA,
        ) for d in Ds],
        optimizer=optimizer,
        learning_rate=0.1,
        eps=0.1,
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()
    if weights_precision == SparseType.INT8:
        emb.init_embedding_weights_uniform(-0.0003, 0.0003)

    nparams = sum(w.numel() for w in emb.split_embedding_weights())

    param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]

    logging.info(f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
                 f"{nparams * param_size_multiplier / 1.0e9: .2f}GB")
    logging.info(
        f"Accessed weights per batch: {B * T * L * D * param_size_multiplier / 1.0e6: .2f}MB"
    )

    requests = generate_requests(
        iters,
        B,
        T,
        L,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=weighted,
    )

    # forward
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        ),
    )
    logging.info(
        f"Forward, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * T * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us")

    grad_output = torch.randn(B, sum(Ds)).cuda()
    # backward
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb(
            indices.long(),
            offsets.long(),
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        ).backward(grad_output),
    )
    logging.info(
        f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f}GB/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us")
Example #6
    def test_backward_optimizers(  # noqa C901
        self,
        T,
        D,
        B,
        log_E,
        L,
        stochastic_rounding,
        weighted,
        mixed,
        optimizer,
        long_segments,
        pooling_mode,
        use_cpu,
    ):
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)

        assume(
            pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            or not weighted)
        mode = ("sum" if pooling_mode
                == split_table_batched_embeddings_ops.PoolingMode.SUM else
                "mean")

        E = int(10**log_E)
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(
                    np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E))
                for _ in range(T)
            ]
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
        if use_cpu:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            ] * T
            compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
        else:
            managed = [
                np.random.choice([
                    split_table_batched_embeddings_ops.EmbeddingLocation.
                    DEVICE,
                    split_table_batched_embeddings_ops.EmbeddingLocation.
                    MANAGED,
                ]) for _ in range(T)
            ]
        bs = [
            to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True),
                      use_cpu) for (E, D) in zip(Es, Ds)
        ]

        xs = [
            to_device(
                torch.from_numpy(
                    np.random.choice(range(e), size=(B, L),
                                     replace=True).astype(np.int64)),
                use_cpu,
            ) for e in Es
        ]
        if long_segments and L > 0:
            for x in xs:
                x[:, 0] = 0

        xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)]
        xws_acc_type = copy.deepcopy(xws)

        fs = ([b_indices(b, x, use_cpu=use_cpu)
               for (b, x) in zip(bs, xs)] if not weighted else [
                   b_indices(
                       b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                   for (b, x, xw) in zip(bs, xs, xws)
               ])
        gos = [torch.randn_like(f) for f in fs]
        [f.backward(go) for (f, go) in zip(fs, gos)]
        # set up reference hyperparameters for the optimizer under test

        optimizer_kwargs = {"learning_rate": 0.5}
        (lr, eps, beta1, beta2, weight_decay, momentum, eta) = (
            0.5,
            1e-4,
            0.9,
            0.99,
            0.01,
            0.9,
            0.01,
        )
        if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD,
                         OptimType.EXACT_ADAGRAD):
            optimizer_kwargs["eps"] = eps

        if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
            optimizer_kwargs["eps"] = eps
            optimizer_kwargs["beta1"] = beta1
            optimizer_kwargs["beta2"] = beta2
            optimizer_kwargs["weight_decay"] = weight_decay

        if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
            optimizer_kwargs["eps"] = eps
            optimizer_kwargs["beta1"] = beta1
            optimizer_kwargs["beta2"] = beta2
            optimizer_kwargs["weight_decay"] = weight_decay

        if optimizer == OptimType.LARS_SGD:
            optimizer_kwargs["weight_decay"] = weight_decay
            optimizer_kwargs["momentum"] = momentum
            optimizer_kwargs["eta"] = eta

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
            optimizer=optimizer,
            stochastic_rounding=stochastic_rounding,
            pooling_mode=pooling_mode,
            **optimizer_kwargs,
        )

        for t in range(T):
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
        xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
        fc2 = (cc(indices, offsets) if not weighted else cc(
            indices, offsets, to_device(xw.contiguous().view(-1), use_cpu)))
        fc2.backward(torch.cat([go.view(B, -1) for go in gos], dim=1))
        cc.flush()

        split_optimizer_states = cc.split_optimizer_states()
        assert len(split_optimizer_states) == T
        split_weights = cc.split_embedding_weights()

        if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD,
                         OptimType.EXACT_ADAGRAD):
            rowwise = optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
            for t in range(T):
                (m1, ) = split_optimizer_states[t]
                m1_ref = (bs[t].weight.grad.to_dense().pow(2) if not rowwise
                          else bs[t].weight.grad.to_dense().pow(2).mean(dim=1))
                torch.testing.assert_allclose(m1.float(),
                                              m1_ref.float(),
                                              atol=1.0e-4,
                                              rtol=1.0e-4)
                weights_new = split_weights[t]
                weights_ref = bs[t].weight - lr * bs[t].weight.grad.to_dense(
                ) / (torch.sqrt(m1_ref if not rowwise else m1_ref.
                                view(m1_ref.numel(), 1)) + eps)
                # TODO: why is tolerance off here?
                torch.testing.assert_allclose(weights_new.float(),
                                              weights_ref.float(),
                                              atol=1.0e-2,
                                              rtol=1.0e-2)

        if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
            rowwise = optimizer == OptimType.PARTIAL_ROWWISE_ADAM
            for t in range(T):
                (m1, m2) = split_optimizer_states[t]
                m2_ref = (bs[t].weight.grad.to_dense().pow(2) if not rowwise
                          else bs[t].weight.grad.to_dense().pow(2).mean(
                              dim=1)) * (1.0 - beta2)
                torch.testing.assert_allclose(m2,
                                              m2_ref,
                                              atol=1.0e-4,
                                              rtol=1.0e-4)
                m1_ref = bs[t].weight.grad.to_dense() * (1.0 - beta1)
                torch.testing.assert_allclose(m1,
                                              m1_ref,
                                              atol=1.0e-4,
                                              rtol=1.0e-4)
                iter_ = cc.iter.item()
                v_hat_t = m2_ref / (1 - beta2**iter_)
                v_hat_t = v_hat_t if not rowwise else v_hat_t.view(
                    v_hat_t.numel(), 1)
                m_hat_t = m1_ref / (1 - beta1**iter_)
                weights_new = split_weights[t]
                weights_ref = (torch.addcdiv(
                    bs[t].weight,
                    value=-lr,
                    tensor1=m_hat_t,
                    tensor2=v_hat_t.sqrt_().add_(eps),
                ) - lr * weight_decay * bs[t].weight)
                torch.testing.assert_allclose(
                    weights_new.index_select(dim=0, index=x[t].view(-1)),
                    weights_ref.index_select(dim=0, index=x[t].view(-1)),
                    atol=1.0e-3,
                    rtol=1.0e-3,
                )

        if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
            rowwise = optimizer == OptimType.PARTIAL_ROWWISE_LAMB
            for t in range(T):
                (m1, m2) = split_optimizer_states[t]
                m2_ref = (bs[t].weight.grad.to_dense().pow(2) if not rowwise
                          else bs[t].weight.grad.to_dense().pow(2).mean(
                              dim=1)) * (1.0 - beta2)
                torch.testing.assert_allclose(m2,
                                              m2_ref,
                                              atol=1.0e-4,
                                              rtol=1.0e-4)
                m1_ref = bs[t].weight.grad.to_dense() * (1.0 - beta1)
                torch.testing.assert_allclose(m1,
                                              m1_ref,
                                              atol=1.0e-4,
                                              rtol=1.0e-4)
                iter_ = cc.iter.item()
                v_hat_t = m2_ref / (1 - beta2**iter_)
                v_hat_t = v_hat_t if not rowwise else v_hat_t.view(
                    v_hat_t.numel(), 1)
                m_hat_t = m1_ref / (1 - beta1**iter_)
                rtw = (
                    m_hat_t /
                    (torch.sqrt(v_hat_t) + eps)) + weight_decay * bs[t].weight
                true_ratio = torch.linalg.norm(
                    bs[t].weight, dim=1, ord=2).view(
                        m1.shape[0], 1) / torch.linalg.norm(
                            rtw, dim=1, ord=2).view(m1.shape[0], 1)
                weights_new = split_weights[t]
                weights_ref = bs[t].weight - lr * true_ratio * rtw
                torch.testing.assert_allclose(
                    weights_new.index_select(dim=0, index=x[t].view(-1)),
                    weights_ref.index_select(dim=0, index=x[t].view(-1)),
                    atol=1.0e-3,
                    rtol=1.0e-3,
                )

        if optimizer == OptimType.LARS_SGD:
            for t in range(T):
                (m1, ) = split_optimizer_states[t]
                weight_norm = torch.linalg.norm(bs[t].weight, dim=1,
                                                ord=2).view(m1.shape[0], 1)
                grad_norm = torch.linalg.norm(bs[t].weight.grad.to_dense(),
                                              dim=1,
                                              ord=2).view(m1.shape[0], 1)
                adjusted_lr = (lr * eta * weight_norm /
                               (grad_norm + weight_decay * weight_norm))
                m1_ref = adjusted_lr * (bs[t].weight.grad.to_dense() +
                                        weight_decay * bs[t].weight)

                torch.testing.assert_allclose(
                    m1.index_select(dim=0, index=x[t].view(-1)),
                    m1_ref.index_select(dim=0, index=x[t].view(-1)),
                    atol=1.0e-4,
                    rtol=1.0e-4,
                )
                weights_new = split_weights[t]
                weights_ref = bs[t].weight - m1_ref
                torch.testing.assert_allclose(
                    weights_new.index_select(dim=0, index=x[t].view(-1)),
                    weights_ref.index_select(dim=0, index=x[t].view(-1)),
                    atol=1.0e-4,
                    rtol=1.0e-4,
                )
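
For EXACT_ROWWISE_ADAGRAD the test reconstructs the expected update from the dense reference gradient: the accumulator is the per-row mean of squared gradients (first step from a zero state) and the weight step divides by its square root plus eps. A small dense sketch of that arithmetic:

import torch

lr, eps = 0.5, 1e-4
weight = torch.randn(6, 4)
grad = torch.randn(6, 4)   # stands in for bs[t].weight.grad.to_dense()

# Row-wise Adagrad keeps one accumulator per row; after the first step from a
# zero state it is just the per-row mean of squared gradients.
m1 = grad.pow(2).mean(dim=1)
new_weight = weight - lr * grad / (torch.sqrt(m1).view(-1, 1) + eps)
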
Example #7
    def test_backward_adagrad(  # noqa C901
        self,
        T,
        D,
        B,
        log_E,
        L,
        D_gradcheck,
        weights_precision,
        stochastic_rounding,
        weighted,
        row_wise,
        mixed,
        use_cache,
        cache_algorithm,
        pooling_mode,
        use_cpu,
    ):
        # NOTE: cache is not applicable to CPU version.
        assume(not use_cpu or not use_cache)
        exact = True  # Only exact sparse optimizers are supported

        # NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version
        #       so we have to limit (T * B * L * D)!
        assume(not use_cpu or T * B * L * D <= 1024)
        assume(not (use_cpu and weights_precision == SparseType.FP16))

        assume(
            pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            or not weighted)
        mode = ("sum" if pooling_mode
                == split_table_batched_embeddings_ops.PoolingMode.SUM else
                "mean")

        # stochastic rounding only implemented for rowwise
        assume(not stochastic_rounding or row_wise)
        # need unique indices for non-exact tests
        assume(exact or int(10**log_E) > int(2.1 * B * L))
        # only row-wise supports caching
        assume(row_wise or not use_cache)

        E = int(10**log_E)
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(
                    np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E))
                for _ in range(T)
            ]
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
        if use_cpu:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            ] * T
            compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
        elif use_cache:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.
                MANAGED_CACHING
            ] * T
            if mixed:
                average_D = sum(Ds) // T
                for t, d in enumerate(Ds):
                    managed[t] = (split_table_batched_embeddings_ops.
                                  EmbeddingLocation.DEVICE
                                  if d < average_D else managed[t])
        else:
            managed = [
                np.random.choice([
                    split_table_batched_embeddings_ops.EmbeddingLocation.
                    DEVICE,
                    split_table_batched_embeddings_ops.EmbeddingLocation.
                    MANAGED,
                ]) for _ in range(T)
            ]
        bs = [
            to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True),
                      use_cpu) for (E, D) in zip(Es, Ds)
        ]

        if weights_precision == SparseType.FP16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            bs = [b.half() for b in bs]

        feature_table_map = list(range(T))
        if exact:
            # autograd with shared embedding only works for exact
            table_to_replicate = T // 2
            bs.insert(table_to_replicate, bs[table_to_replicate])
            feature_table_map.insert(table_to_replicate, table_to_replicate)

        xs = [
            to_device(
                torch.from_numpy(
                    np.random.choice(range(Es[t]), size=(B, L),
                                     replace=exact).astype(np.int64)),
                use_cpu,
            ) for t in feature_table_map
        ]
        xws = [
            to_device(torch.randn(size=(B, L)), use_cpu)
            for _ in range(len(xs))
        ]
        xws_acc_type = copy.deepcopy(xws)

        if weights_precision == SparseType.FP16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            xws = [xw.half() for xw in xws]

        fs = ([b_indices(b, x, use_cpu=use_cpu)
               for (b, x) in zip(bs, xs)] if not weighted else [
                   b_indices(
                       b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                   for (b, x, xw) in zip(bs, xs, xws)
               ])
        gos = [torch.randn_like(f) for f in fs]
        [f.backward(go) for (f, go) in zip(fs, gos)]
        # reference Adagrad hyperparameters
        lr = 0.5
        eps = 0.2

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
            feature_table_map=feature_table_map,
            optimizer=OptimType.EXACT_ROWWISE_ADAGRAD
            if row_wise else OptimType.EXACT_ADAGRAD,
            learning_rate=lr,
            eps=eps,
            weights_precision=weights_precision,
            stochastic_rounding=stochastic_rounding,
            pooling_mode=pooling_mode,
        )

        if exact:
            del bs[table_to_replicate]
        for t in range(T):
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
        xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
        fc2 = (cc(indices, offsets) if not weighted else cc(
            indices, offsets, to_device(xw.contiguous().view(-1), use_cpu)))
        fc2.backward(torch.cat([go.view(B, -1) for go in gos], dim=1))
        cc.flush()

        split_optimizer_states = [s for (s, ) in cc.split_optimizer_states()]
        for t in range(T):
            ref_optimizer_state = bs[t].weight.grad.float().to_dense().pow(2)
            torch.testing.assert_allclose(
                split_optimizer_states[t].float(),
                ref_optimizer_state.mean(
                    dim=1) if row_wise else ref_optimizer_state,
                atol=5.0e-3
                if weights_precision == SparseType.FP16 else 1.0e-4,
                rtol=5.0e-3
                if weights_precision == SparseType.FP16 else 1.0e-4,
            )

        for t in range(T):
            # optimizer_state = squares (no row-wise) or sum squares (row-wise)
            torch.testing.assert_allclose(
                cc.split_embedding_weights()[t].float(),
                torch.addcdiv(
                    bs[t].weight.float(),
                    value=-lr,
                    tensor1=bs[t].weight.grad.float().to_dense(),
                    tensor2=split_optimizer_states[t].float().sqrt_().add_(
                        eps).view(Es[t], 1 if row_wise else Ds[t]),
                ),
                atol=5.0e-3
                if weights_precision == SparseType.FP16 else 1.0e-4,
                rtol=5.0e-3
                if weights_precision == SparseType.FP16 else 1.0e-4,
            )
        if use_cpu:
            D_gradcheck = (D_gradcheck + 15) // 16 * 4
        else:
            D_gradcheck = D_gradcheck * 4
        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D_gradcheck, M, compute_device)
             for (E, M) in zip(Es, managed)],
            feature_table_map=feature_table_map,
            optimizer=OptimType.EXACT_ROWWISE_ADAGRAD
            if row_wise else OptimType.EXACT_ADAGRAD,
            learning_rate=0.0,
            eps=eps,
            weights_precision=weights_precision,
            stochastic_rounding=stochastic_rounding,
            # NOTE: only SUM pooling can work with per_sample_weights!
            pooling_mode=split_table_batched_embeddings_ops.PoolingMode.SUM,
        )
        if use_cpu:
            # NOTE: GPU version of SplitTableBatchedEmbeddingBagsCodegen doesn't support double.
            cc = cc.double()

        per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
        if use_cpu:
            per_sample_weights = per_sample_weights.double()
        per_sample_weights.requires_grad = True
        indices.requires_grad = False
        offsets.requires_grad = False
        for param in cc.parameters():
            param.requires_grad = False
        torch.autograd.gradcheck(cc, (indices, offsets, per_sample_weights))

        per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
        if use_cpu:
            per_sample_weights = per_sample_weights.double()
        per_sample_weights.requires_grad = True
        indices.requires_grad = False
        offsets.requires_grad = False
        for param in cc.parameters():
            param.requires_grad = False
        y = cc(indices, offsets, per_sample_weights)
        y.sum().backward()
        indice_weight_grad_all = per_sample_weights.grad.clone().cpu()
        T_ = len(xws)
        feature_requires_grad = to_device(
            torch.tensor(np.random.choice([0, 1], replace=True,
                                          size=(T_, ))).int(),
            use_cpu,
        )
        per_sample_weights = per_sample_weights.detach().clone()
        per_sample_weights.requires_grad = True
        y = cc(
            indices,
            offsets,
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        )
        y.sum().backward()
        indice_weight_grad_mask = per_sample_weights.grad.clone().cpu()
        for t in range(T_):
            if feature_requires_grad[t]:
                torch.testing.assert_allclose(
                    indice_weight_grad_mask.view(T_, B, L)[t],
                    indice_weight_grad_all.view(T_, B, L)[t],
                )
            else:
                torch.testing.assert_allclose(
                    indice_weight_grad_mask.view(T_, B, L)[t],
                    torch.zeros_like(
                        indice_weight_grad_mask.view(T_, B, L)[t]),
                )
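
The gradcheck section exercises gradients with respect to per_sample_weights only, with every other input frozen and the module in double precision. A minimal sketch of the same pattern using plain torch.nn.functional.embedding_bag in place of the FBGEMM op:

import torch
import torch.nn.functional as F

E, D, B, L = 10, 4, 3, 2
weight = torch.randn(E, D, dtype=torch.double)  # frozen, like cc.parameters()
indices = torch.randint(0, E, (B * L,))
offsets = torch.arange(0, B * L, L)             # one bag of length L per sample
per_sample_weights = torch.randn(B * L, dtype=torch.double, requires_grad=True)

torch.autograd.gradcheck(
    lambda psw: F.embedding_bag(indices, weight, offsets, mode="sum",
                                per_sample_weights=psw),
    (per_sample_weights,),
)
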
Example #8
    def test_backward_sgd(  # noqa C901
        self,
        T,
        D,
        B,
        log_E,
        L,
        weights_precision,
        weighted,
        mixed,
        use_cache,
        cache_algorithm,
        long_segments,
        pooling_mode,
        use_cpu,
    ):
        # NOTE: cache is not applicable to CPU version.
        assume(not use_cpu or not use_cache)
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)
        assume(not (use_cpu and weights_precision == SparseType.FP16))

        assume(
            pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            or not weighted)
        mode = ("sum" if pooling_mode
                == split_table_batched_embeddings_ops.PoolingMode.SUM else
                "mean")

        E = int(10**log_E)
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(
                    np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E))
                for _ in range(T)
            ]
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
        if use_cpu:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            ] * T
            compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
        elif use_cache:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.
                MANAGED_CACHING
            ] * T
            if mixed:
                average_D = sum(Ds) // T
                for t, d in enumerate(Ds):
                    managed[t] = (split_table_batched_embeddings_ops.
                                  EmbeddingLocation.DEVICE
                                  if d < average_D else managed[t])
        else:
            managed = [
                np.random.choice([
                    split_table_batched_embeddings_ops.EmbeddingLocation.
                    DEVICE,
                    split_table_batched_embeddings_ops.EmbeddingLocation.
                    MANAGED,
                ]) for _ in range(T)
            ]
        bs = [
            to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True),
                      use_cpu) for (E, D) in zip(Es, Ds)
        ]

        if weights_precision == SparseType.FP16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            bs = [b.half() for b in bs]

        feature_table_map = list(range(T))
        table_to_replicate = T // 2
        bs.insert(table_to_replicate, bs[table_to_replicate])
        feature_table_map.insert(table_to_replicate, table_to_replicate)

        xs = [
            to_device(
                torch.from_numpy(
                    np.random.choice(range(Es[t]), size=(B, L),
                                     replace=True).astype(np.int64)),
                use_cpu,
            ) for t in feature_table_map
        ]

        if long_segments and L > 0:
            for x in xs:
                x[:, 0] = 0

        xws = [
            to_device(torch.randn(size=(B, L)), use_cpu)
            for _ in range(len(xs))
        ]
        xws_acc_type = copy.deepcopy(xws)

        if weights_precision == SparseType.FP16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            xws = [xw.half() for xw in xws]

        fs = ([b_indices(b, x, use_cpu=use_cpu)
               for (b, x) in zip(bs, xs)] if not weighted else [
                   b_indices(
                       b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                   for (b, x, xw) in zip(bs, xs, xws)
               ])
        gos = [torch.randn_like(f) for f in fs]
        [f.backward(go) for (f, go) in zip(fs, gos)]
        # do SGD update
        lr = 0.05
        del bs[table_to_replicate]
        new_weights = [(b.weight - b.weight.grad * lr) for b in bs]

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
            optimizer=OptimType.EXACT_SGD,
            feature_table_map=feature_table_map,
            learning_rate=0.05,
            weights_precision=weights_precision,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling_mode,
        )

        for t in range(T):
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
        xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
        fc2 = (cc(indices, offsets) if not weighted else cc(
            indices, offsets, to_device(xw.contiguous().view(-1), use_cpu)))
        goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
        fc2.backward(goc)
        if use_cache:
            cc.flush()
        for t in range(T):
            torch.testing.assert_allclose(
                cc.split_embedding_weights()[t],
                new_weights[t].half() if weights_precision == SparseType.FP16
                and not use_cpu else new_weights[t],
                atol=(1.0e-2 if long_segments else 5.0e-3)
                if weights_precision == SparseType.FP16 else 1.0e-5,
                rtol=2.0e-2
                if weights_precision == SparseType.FP16 else 1.0e-5,
            )
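
With sparse=True the reference EmbeddingBag produces a sparse weight gradient, and the expected result of one SGD step is simply weight - lr * grad, as computed for new_weights above. A small standalone sketch:

import torch

lr = 0.05
bag = torch.nn.EmbeddingBag(10, 4, mode="sum", sparse=True)
indices = torch.tensor([[1, 3], [3, 7]])
out = bag(indices)
out.backward(torch.randn_like(out))

# With sparse=True the gradient is a sparse tensor; one SGD step is w - lr * grad.
expected = bag.weight - bag.weight.grad * lr
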
Example #9
    def test_forward(
        self,
        T,
        D,
        B,
        log_E,
        L,
        weights_precision,
        weighted,
        mixed,
        use_cache,
        cache_algorithm,
        pooling_mode,
        use_cpu,
    ):
        # NOTE: cache is not applicable to CPU version.
        assume(not use_cpu or not use_cache)
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)
        assume(not (use_cpu and weights_precision == SparseType.FP16))

        assume(
            pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
            or not weighted)
        mode = ("sum" if pooling_mode
                == split_table_batched_embeddings_ops.PoolingMode.SUM else
                "mean")

        E = int(10**log_E)
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            Ds = [
                div_round_up(
                    np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E))
                for _ in range(T)
            ]
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
        if use_cpu:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.HOST
            ] * T
            compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
        elif use_cache:
            managed = [
                split_table_batched_embeddings_ops.EmbeddingLocation.
                MANAGED_CACHING
            ] * T
            if mixed:
                average_D = sum(Ds) // T
                for t, d in enumerate(Ds):
                    managed[t] = (split_table_batched_embeddings_ops.
                                  EmbeddingLocation.DEVICE
                                  if d < average_D else managed[t])
        else:
            managed = [
                np.random.choice([
                    split_table_batched_embeddings_ops.EmbeddingLocation.
                    DEVICE,
                    split_table_batched_embeddings_ops.EmbeddingLocation.
                    MANAGED,
                ]) for _ in range(T)
            ]
        bs = [
            to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True),
                      use_cpu) for (E, D) in zip(Es, Ds)
        ]
        if weights_precision == SparseType.FP16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            bs = [b.half() for b in bs]

        xs = [
            to_device(torch.randint(low=0, high=e, size=(B, L)), use_cpu)
            for e in Es
        ]
        xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)]
        xws_acc_type = copy.deepcopy(xws)

        if weights_precision == SparseType.FP16 and not use_cpu:
            # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
            xws = [xw.half() for xw in xws]

        fs = ([b_indices(b, x, use_cpu=use_cpu)
               for (b, x) in zip(bs, xs)] if not weighted else [
                   b_indices(
                       b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                   for (b, x, xw) in zip(bs, xs, xws)
               ])
        f = torch.cat([f.view(B, -1) for f in fs], dim=1)

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            embedding_specs=[(
                E,
                D,
                split_table_batched_embeddings_ops.EmbeddingLocation(M),
                compute_device,
            ) for (E, D, M) in zip(Es, Ds, managed)],
            weights_precision=weights_precision,
            optimizer=OptimType.EXACT_SGD,
            learning_rate=0.05,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling_mode,
        )
        # NOTE: check that the module is TorchScript-compatible.
        cc = torch.jit.script(cc)

        for t in range(T):
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
        xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
        fc2 = (cc(indices, offsets) if not weighted else cc(
            indices, offsets, to_device(xw.contiguous().view(-1), use_cpu)))
        torch.testing.assert_allclose(
            fc2.float(),
            f.float(),
            atol=8.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5,
            rtol=8.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5,
        )
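
get_table_batched_offsets_from_dense is a helper not shown in this snippet; for a fixed bag size L it amounts to flattening the (T, B, L) index tensor and emitting one offset per bag. A sketch under that assumption:

import torch


def dense_to_batched(dense_indices):
    # dense_indices: (T, B, L) int64 tensor with a fixed bag size L.
    T, B, L = dense_indices.shape
    indices = dense_indices.contiguous().view(-1)
    offsets = torch.arange(0, T * B * L + 1, L)  # T * B + 1 entries
    return indices, offsets


x = torch.randint(0, 100, (2, 3, 4))
indices, offsets = dense_to_batched(x)
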