def test_cache_pipeline(self, T, D, B, log_E, L, mixed, cache_algorithm):
    iters = 3
    E = int(10**log_E)
    D = D * 4
    if not mixed:
        Ds = [D] * T
        Es = [E] * T
    else:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        Es = [
            np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
        ]
    managed = [
        split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
    ] * T
    if mixed:
        average_D = sum(Ds) // T
        for t, d in enumerate(Ds):
            managed[t] = (
                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                if d < average_D
                else managed[t]
            )
    cc_ref = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                D,
                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for (E, D) in zip(Es, Ds)
        ],
    )
    cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(E, D, M, split_table_batched_embeddings_ops.ComputeDevice.CUDA)
         for (E, D, M) in zip(Es, Ds, managed)],
        cache_algorithm=cache_algorithm,
    )
    for t in range(T):
        assert (
            cc.split_embedding_weights()[t].size()
            == cc_ref.split_embedding_weights()[t].size()
        )
        cc.split_embedding_weights()[t].data.copy_(
            cc_ref.split_embedding_weights()[t]
        )

    requests = generate_requests(iters, B, T, L, min(Es), reuse=0.1)
    grad_output = torch.randn(B, sum(Ds)).cuda()

    for indices, offsets, _ in requests:
        output = cc(indices, offsets)
        output_ref = cc_ref(indices, offsets)
        torch.testing.assert_allclose(output, output_ref)
        output.backward(grad_output)
        output_ref.backward(grad_output)
    cc.flush()
    for t in range(T):
        torch.testing.assert_allclose(
            cc.split_embedding_weights()[t], cc_ref.split_embedding_weights()[t]
        )
def __init__(
    self,
    emb_dim,
    num_tables,
    num_rows,
    use_cpu,
) -> None:
    super().__init__()
    pooling_mode = split_table_batched_embeddings_ops.PoolingMode.SUM
    Ds = [emb_dim] * num_tables
    Es = [num_rows] * num_tables

    device = (
        split_table_batched_embeddings_ops.ComputeDevice.CPU
        if use_cpu
        else split_table_batched_embeddings_ops.ComputeDevice.CUDA
    )
    loc = (
        split_table_batched_embeddings_ops.EmbeddingLocation.HOST
        if use_cpu
        else split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
    )

    self.emb_module = (
        split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
            embedding_specs=[
                (
                    E,
                    D,
                    loc,
                    device,
                )
                for (E, D) in zip(Es, Ds)
            ],
            weights_precision=SparseType.FP32,
            optimizer=OptimType.EXACT_SGD,
            learning_rate=0.05,
            pooling_mode=pooling_mode,
        )
    )
    self.emb_module.init_embedding_weights_uniform(
        -EMB_WEIGHT_UNIFORM_INIT_BOUND, +EMB_WEIGHT_UNIFORM_INIT_BOUND
    )
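# A minimal sketch (an assumption, not part of the original wrapper): a forward
# method that simply delegates to emb_module with flattened indices and
# CSR-style offsets, matching how the tests below drive the op. The optional
# third positional argument is per-sample weights, as used by the weighted
# tests further down.
def forward(self, indices, offsets, per_sample_weights=None):
    # indices: 1-D int64 tensor of all lookups, concatenated per (table, bag).
    # offsets: 1-D int64 tensor of length num_bags + 1 (prefix sum of bag lengths).
    return self.emb_module(indices, offsets, per_sample_weights)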
def cache(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    cache_algorithm: str,
    cache_sets: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    long_index: bool,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    weighted: bool,
) -> None:
    np.random.seed(42)
    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    cache_alg = (
        split_table_batched_embeddings_ops.CacheAlgorithm.LRU
        if cache_algorithm == "lru"
        else split_table_batched_embeddings_ops.CacheAlgorithm.LFU
    )
    if mixed:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T

    emb_nc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds
        ],
        optimizer=optimizer,
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_nc.init_embedding_weights_uniform(-0.0003, 0.0003)

    emb = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds
        ],
        optimizer=optimizer,
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
        cache_sets=cache_sets,
        cache_algorithm=cache_alg,
    ).cuda()
    if weights_precision == SparseType.INT8:
        emb.init_embedding_weights_uniform(-0.0003, 0.0003)

    nparams = sum(w.numel() for w in emb.split_embedding_weights())
    param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]
    logging.info(
        f"Embedding tables: {E * T} rows, {nparams / 1.0e9: .2f} GParam, "
        f"{nparams * param_size_multiplier / 1.0e6: .2f}MB"
    )
    logging.info(
        f"Accessed weights per batch: {B * T * L} rows, "
        f"{B * T * L * D * param_size_multiplier / 1.0e6: .2f}MB"
    )

    requests = generate_requests(
        2 * iters, B, T, L, E, reuse=reuse, alpha=alpha, weighted=weighted
    )
    warmup_requests, requests = requests[:iters], requests[iters:]
    grad_output = torch.randn(B, sum(Ds)).cuda()

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb_nc(
            indices.long(), offsets.long(), per_sample_weights
        ).backward(grad_output),
    )
    logging.info(
        f"ForwardBackward (UVM), B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f}GB/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    # warm up
    for indices, offsets, _ in warmup_requests:
        emb.forward(indices.long(), offsets.long())
    # get cache miss rate (forward and backward) and exchanged cache lines (prefetch)
    cache_misses = []
    exchanged_cache_lines = []
    NOT_FOUND = -1
    for indices, offsets, _ in requests:
        # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no attribute
        #  `lxu_cache_state`.
        old_lxu_cache_state = emb.lxu_cache_state.clone()
        emb.prefetch(indices.long(), offsets.long())
        exchanged_cache_lines.append(
            (emb.lxu_cache_state != old_lxu_cache_state).sum().item()
        )
        cache_misses.append(
            (emb.lxu_cache_locations_list[0] == NOT_FOUND).sum().item()
        )
        emb.forward(indices.long(), offsets.long())
    logging.info(
        f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines)/len(requests): .2f}, "
        f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}"
    )
    logging.info(
        f"Cache miss -- mean: {sum(cache_misses)/len(requests)}, "
        f"max: {max(cache_misses)}, min: {min(cache_misses)}"
    )

    # benchmark prefetch
    emb.reset_cache_states()
    for indices, offsets, _ in warmup_requests:
        emb.forward(indices, offsets)
    prefetch_time, forward_backward_time = benchmark_pipelined_requests(
        requests,
        lambda indices, offsets, indices_weights: emb.prefetch(indices, offsets),
        lambda indices, offsets, indices_weights: emb.forward(
            indices, offsets, indices_weights
        ).backward(grad_output),
    )
    e2e_time = prefetch_time + forward_backward_time

    logging.info(
        f"ForwardBackward (LXU), reuse: {reuse}, alpha: {alpha}, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / e2e_time / 1.0e9: .2f}GB/s, "
        f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, "
        f"{2 * sum(exchanged_cache_lines) * param_size_multiplier * D / prefetch_time / len(requests) / 1.0e9: .2f} GB/s, "
        f"Tfwdbwd: {forward_backward_time * 1.0e6:.0f}us, "
        f"{3 * param_size_multiplier * B * sum(Ds) * L / forward_backward_time / 1.0e9: .2f} GB/s, "
        f"Te2e: {e2e_time * 1.0e6:.0f}us, "
    )
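# A minimal sketch of the explicit prefetch-then-forward pattern that the cache
# benchmark above exercises for MANAGED_CACHING tables. `emb`, the request
# tuples, and `grad_output` follow the definitions above; no calls beyond those
# already used (prefetch, forward, backward) are assumed.
def run_with_explicit_prefetch(emb, requests, grad_output):
    # For each batch, populate the LXU cache before the fused forward/backward.
    for indices, offsets, per_sample_weights in requests:
        emb.prefetch(indices.long(), offsets.long())
        out = emb.forward(indices.long(), offsets.long(), per_sample_weights)
        out.backward(grad_output)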
def uvm(
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    uvm_tables: int,
    uvm_bag_size: int,
    weighted: bool,
) -> None:
    np.random.seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    T_uvm = uvm_tables
    assert T_uvm <= T
    T_gpu = T - T_uvm
    L_uvm = uvm_bag_size

    if mixed:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T

    emb_uvm = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds[:T_uvm]
        ],
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_uvm.init_embedding_weights_uniform(-0.0003, 0.0003)

    emb_gpu = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds[T_uvm:]
        ],
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_gpu.init_embedding_weights_uniform(-0.0003, 0.0003)

    emb_mixed = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                managed_option,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for (d, managed_option) in zip(
                Ds,
                [split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED] * T_uvm
                + [split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE] * T_gpu,
            )
        ],
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_mixed.init_embedding_weights_uniform(-0.0003, 0.0003)

    requests_uvm = generate_requests(
        iters,
        B,
        T_uvm,
        L_uvm,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=weighted,
    )
    requests_gpu = generate_requests(
        iters,
        B,
        T_gpu,
        L,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=False,
    )

    requests = []
    for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu):
        indices = torch.cat([rs_uvm[0], rs_gpu[0]])
        lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B)
        offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda()
        per_sample_weights = None
        if weighted:
            assert (this_rs_uvm_weights := rs_uvm[2]) is not None
            assert (this_rs_gpu_weights := rs_gpu[2]) is not None
            per_sample_weights = torch.cat(
                [this_rs_uvm_weights, this_rs_gpu_weights]
            )
        requests.append((indices, offsets, per_sample_weights))

    # forward
    time_per_iter = benchmark_requests(
        requests_gpu,
        lambda indices, offsets, per_sample_weights: emb_gpu.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]
    logging.info(
        f"GPU Forward, B: {B}, "
        f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * T_gpu * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    time_per_iter = benchmark_requests(
        requests_uvm,
        lambda indices, offsets, per_sample_weights: emb_uvm.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    logging.info(
        f"UVM Forward, B: {B}, "
        f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * T_gpu * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb_mixed.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
    )
    logging.info(
        f"Mixed Forward, B: {B}, "
        f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * T_gpu * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )
def device(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    managed: str,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    row_wise: bool,
    weighted: bool,
    weighted_num_requires_grad: Optional[int],
) -> None:
    np.random.seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    if weighted_num_requires_grad:
        assert weighted_num_requires_grad <= T
        weighted_requires_grad_tables = np.random.choice(
            T, replace=False, size=(weighted_num_requires_grad,)
        ).tolist()
        feature_requires_grad = (
            torch.tensor(
                [1 if t in weighted_requires_grad_tables else 0 for t in range(T)]
            )
            .cuda()
            .int()
        )
    else:
        feature_requires_grad = None
    if mixed:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T
    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

    if managed == "device":
        managed_option = split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
    else:
        managed_option = split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED

    emb = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                managed_option,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
            for d in Ds
        ],
        optimizer=optimizer,
        learning_rate=0.1,
        eps=0.1,
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()
    if weights_precision == SparseType.INT8:
        emb.init_embedding_weights_uniform(-0.0003, 0.0003)

    nparams = sum(w.numel() for w in emb.split_embedding_weights())
    param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]
    logging.info(
        f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
        f"{nparams * param_size_multiplier / 1.0e9: .2f}GB"
    )
    logging.info(
        f"Accessed weights per batch: {B * T * L * D * param_size_multiplier / 1.0e6: .2f}MB"
    )
    requests = generate_requests(
        iters,
        B,
        T,
        L,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=weighted,
    )

    # forward
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        ),
    )
    logging.info(
        f"Forward, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * T * L * D / time_per_iter / 1.0e9: .2f}GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    grad_output = torch.randn(B, sum(Ds)).cuda()
    # backward
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb(
            indices.long(),
            offsets.long(),
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        ).backward(grad_output),
    )
    logging.info(
        f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f}GB/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )
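# A worked example (hypothetical numbers) of the forward-bandwidth estimate
# logged above: bytes read per iteration are roughly
# param_size_multiplier * B * T * L * D.
# With B=512, T=10, L=20, D=128 and FP32 (4 bytes/element):
#   4 * 512 * 10 * 20 * 128 = 52,428,800 bytes, i.e. about 52.4 MB per iteration,
# so a 100 us iteration corresponds to roughly 524 GB/s of achieved bandwidth.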
def test_backward_optimizers(  # noqa C901
    self,
    T,
    D,
    B,
    log_E,
    L,
    stochastic_rounding,
    weighted,
    mixed,
    optimizer,
    long_segments,
    pooling_mode,
    use_cpu,
):
    # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
    assume(not use_cpu or T * B * L * D <= 2048)
    assume(
        pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
        or not weighted
    )
    mode = (
        "sum"
        if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
        else "mean"
    )

    E = int(10**log_E)
    if use_cpu:
        D = (D + 15) // 16 * 4
    else:
        D = D * 4
    if not mixed:
        Ds = [D] * T
        Es = [E] * T
    else:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        Es = [
            np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
        ]
    compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
    if use_cpu:
        managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
    else:
        managed = [
            np.random.choice(
                [
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                    split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                ]
            )
            for _ in range(T)
        ]
    bs = [
        to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
        for (E, D) in zip(Es, Ds)
    ]

    xs = [
        to_device(
            torch.from_numpy(
                np.random.choice(range(e), size=(B, L), replace=True).astype(np.int64)
            ),
            use_cpu,
        )
        for e in Es
    ]
    if long_segments and L > 0:
        for x in xs:
            x[:, 0] = 0

    xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)]
    xws_acc_type = copy.deepcopy(xws)

    fs = (
        [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
        if not weighted
        else [
            b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
            for (b, x, xw) in zip(bs, xs, xws)
        ]
    )
    gos = [torch.randn_like(f) for f in fs]
    [f.backward(go) for (f, go) in zip(fs, gos)]
    # do SGD update

    optimizer_kwargs = {"learning_rate": 0.5}
    (lr, eps, beta1, beta2, weight_decay, momentum, eta) = (
        0.5,
        1e-4,
        0.9,
        0.99,
        0.01,
        0.9,
        0.01,
    )
    if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD, OptimType.EXACT_ADAGRAD):
        optimizer_kwargs["eps"] = eps

    if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
        optimizer_kwargs["eps"] = eps
        optimizer_kwargs["beta1"] = beta1
        optimizer_kwargs["beta2"] = beta2
        optimizer_kwargs["weight_decay"] = weight_decay

    if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
        optimizer_kwargs["eps"] = eps
        optimizer_kwargs["beta1"] = beta1
        optimizer_kwargs["beta2"] = beta2
        optimizer_kwargs["weight_decay"] = weight_decay

    if optimizer == OptimType.LARS_SGD:
        optimizer_kwargs["weight_decay"] = weight_decay
        optimizer_kwargs["momentum"] = momentum
        optimizer_kwargs["eta"] = eta

    cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
        optimizer=optimizer,
        stochastic_rounding=stochastic_rounding,
        pooling_mode=pooling_mode,
        **optimizer_kwargs,
    )

    for t in range(T):
        cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

    x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
    xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

    (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
    fc2 = (
        cc(indices, offsets)
        if not weighted
        else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
    )
    fc2.backward(torch.cat([go.view(B, -1) for go in gos], dim=1))
    cc.flush()

    split_optimizer_states = cc.split_optimizer_states()
    assert len(split_optimizer_states) == T
    split_weights = cc.split_embedding_weights()

    if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD, OptimType.EXACT_ADAGRAD):
        rowwise = optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
        for t in range(T):
            (m1,) = split_optimizer_states[t]
            m1_ref = (
                bs[t].weight.grad.to_dense().pow(2)
                if not rowwise
                else bs[t].weight.grad.to_dense().pow(2).mean(dim=1)
            )
            torch.testing.assert_allclose(
                m1.float(), m1_ref.float(), atol=1.0e-4, rtol=1.0e-4
            )
            weights_new = split_weights[t]
            weights_ref = bs[t].weight - lr * bs[t].weight.grad.to_dense() / (
                torch.sqrt(m1_ref if not rowwise else m1_ref.view(m1_ref.numel(), 1))
                + eps
            )
            # TODO: why is tolerance off here?
            torch.testing.assert_allclose(
                weights_new.float(), weights_ref.float(), atol=1.0e-2, rtol=1.0e-2
            )

    if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
        rowwise = optimizer == OptimType.PARTIAL_ROWWISE_ADAM
        for t in range(T):
            (m1, m2) = split_optimizer_states[t]
            m2_ref = (
                bs[t].weight.grad.to_dense().pow(2)
                if not rowwise
                else bs[t].weight.grad.to_dense().pow(2).mean(dim=1)
            ) * (1.0 - beta2)
            torch.testing.assert_allclose(m2, m2_ref, atol=1.0e-4, rtol=1.0e-4)
            m1_ref = bs[t].weight.grad.to_dense() * (1.0 - beta1)
            torch.testing.assert_allclose(m1, m1_ref, atol=1.0e-4, rtol=1.0e-4)
            iter_ = cc.iter.item()
            v_hat_t = m2_ref / (1 - beta2**iter_)
            v_hat_t = v_hat_t if not rowwise else v_hat_t.view(v_hat_t.numel(), 1)
            m_hat_t = m1_ref / (1 - beta1**iter_)
            weights_new = split_weights[t]
            weights_ref = (
                torch.addcdiv(
                    bs[t].weight,
                    value=-lr,
                    tensor1=m_hat_t,
                    tensor2=v_hat_t.sqrt_().add_(eps),
                )
                - lr * weight_decay * bs[t].weight
            )
            torch.testing.assert_allclose(
                weights_new.index_select(dim=0, index=x[t].view(-1)),
                weights_ref.index_select(dim=0, index=x[t].view(-1)),
                atol=1.0e-3,
                rtol=1.0e-3,
            )

    if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
        rowwise = optimizer == OptimType.PARTIAL_ROWWISE_LAMB
        for t in range(T):
            (m1, m2) = split_optimizer_states[t]
            m2_ref = (
                bs[t].weight.grad.to_dense().pow(2)
                if not rowwise
                else bs[t].weight.grad.to_dense().pow(2).mean(dim=1)
            ) * (1.0 - beta2)
            torch.testing.assert_allclose(m2, m2_ref, atol=1.0e-4, rtol=1.0e-4)
            m1_ref = bs[t].weight.grad.to_dense() * (1.0 - beta1)
            torch.testing.assert_allclose(m1, m1_ref, atol=1.0e-4, rtol=1.0e-4)
            iter_ = cc.iter.item()
            v_hat_t = m2_ref / (1 - beta2**iter_)
            v_hat_t = v_hat_t if not rowwise else v_hat_t.view(v_hat_t.numel(), 1)
            m_hat_t = m1_ref / (1 - beta1**iter_)
            rtw = (m_hat_t / (torch.sqrt(v_hat_t) + eps)) + weight_decay * bs[t].weight
            true_ratio = torch.linalg.norm(bs[t].weight, dim=1, ord=2).view(
                m1.shape[0], 1
            ) / torch.linalg.norm(rtw, dim=1, ord=2).view(m1.shape[0], 1)
            weights_new = split_weights[t]
            weights_ref = bs[t].weight - lr * true_ratio * rtw
            torch.testing.assert_allclose(
                weights_new.index_select(dim=0, index=x[t].view(-1)),
                weights_ref.index_select(dim=0, index=x[t].view(-1)),
                atol=1.0e-3,
                rtol=1.0e-3,
            )

    if optimizer == OptimType.LARS_SGD:
        for t in range(T):
            (m1,) = split_optimizer_states[t]
            weight_norm = torch.linalg.norm(bs[t].weight, dim=1, ord=2).view(
                m1.shape[0], 1
            )
            grad_norm = torch.linalg.norm(
                bs[t].weight.grad.to_dense(), dim=1, ord=2
            ).view(m1.shape[0], 1)
            adjusted_lr = (
                lr * eta * weight_norm / (grad_norm + weight_decay * weight_norm)
            )
            m1_ref = adjusted_lr * (
                bs[t].weight.grad.to_dense() + weight_decay * bs[t].weight
            )
            torch.testing.assert_allclose(
                m1.index_select(dim=0, index=x[t].view(-1)),
                m1_ref.index_select(dim=0, index=x[t].view(-1)),
                atol=1.0e-4,
                rtol=1.0e-4,
            )
            weights_new = split_weights[t]
            weights_ref = bs[t].weight - m1_ref
            torch.testing.assert_allclose(
                weights_new.index_select(dim=0, index=x[t].view(-1)),
                weights_ref.index_select(dim=0, index=x[t].view(-1)),
                atol=1.0e-4,
                rtol=1.0e-4,
            )
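# A compact restatement (sketch only, mirroring the single-step reference above)
# of the exact row-wise Adagrad update the test checks: starting from a zero
# accumulator, the optimizer state is the per-row mean of squared gradients,
# and each row is scaled by that single value.
def rowwise_adagrad_reference(weight, grad, lr, eps):
    # weight, grad: dense [rows, dim] tensors; returns (new_weight, state).
    m1 = grad.pow(2).mean(dim=1)  # one accumulator per row
    new_weight = weight - lr * grad / (torch.sqrt(m1).view(-1, 1) + eps)
    return new_weight, m1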
def test_backward_adagrad(  # noqa C901
    self,
    T,
    D,
    B,
    log_E,
    L,
    D_gradcheck,
    weights_precision,
    stochastic_rounding,
    weighted,
    row_wise,
    mixed,
    use_cache,
    cache_algorithm,
    pooling_mode,
    use_cpu,
):
    # NOTE: cache is not applicable to CPU version.
    assume(not use_cpu or not use_cache)

    exact = True  # Only exact sparse optimizers are supported

    # NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version
    # so we have to limit (T * B * L * D)!
    assume(not use_cpu or T * B * L * D <= 1024)
    assume(not (use_cpu and weights_precision == SparseType.FP16))

    assume(
        pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
        or not weighted
    )
    mode = (
        "sum"
        if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
        else "mean"
    )

    # stochastic rounding only implemented for rowwise
    assume(not stochastic_rounding or row_wise)
    # need unique indices for non-exact tests
    assume(exact or int(10**log_E) > int(2.1 * B * L))
    # only row-wise supports caching
    assume(row_wise or not use_cache)

    E = int(10**log_E)
    if use_cpu:
        D = (D + 15) // 16 * 4
    else:
        D = D * 4

    if not mixed:
        Ds = [D] * T
        Es = [E] * T
    else:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        Es = [
            np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
        ]
    compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
    if use_cpu:
        managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
    elif use_cache:
        managed = [
            split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
        ] * T
        if mixed:
            average_D = sum(Ds) // T
            for t, d in enumerate(Ds):
                managed[t] = (
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                    if d < average_D
                    else managed[t]
                )
    else:
        managed = [
            np.random.choice(
                [
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                    split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                ]
            )
            for _ in range(T)
        ]
    bs = [
        to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
        for (E, D) in zip(Es, Ds)
    ]

    if weights_precision == SparseType.FP16 and not use_cpu:
        # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
        bs = [b.half() for b in bs]

    feature_table_map = list(range(T))
    if exact:
        # autograd with shared embedding only works for exact
        table_to_replicate = T // 2
        bs.insert(table_to_replicate, bs[table_to_replicate])
        feature_table_map.insert(table_to_replicate, table_to_replicate)

    xs = [
        to_device(
            torch.from_numpy(
                np.random.choice(range(Es[t]), size=(B, L), replace=exact).astype(
                    np.int64
                )
            ),
            use_cpu,
        )
        for t in feature_table_map
    ]

    xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(len(xs))]
    xws_acc_type = copy.deepcopy(xws)

    if weights_precision == SparseType.FP16 and not use_cpu:
        # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
        xws = [xw.half() for xw in xws]

    fs = (
        [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
        if not weighted
        else [
            b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
            for (b, x, xw) in zip(bs, xs, xws)
        ]
    )
    gos = [torch.randn_like(f) for f in fs]
    [f.backward(go) for (f, go) in zip(fs, gos)]
    # do SGD update
    lr = 0.5
    eps = 0.2

    cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
        feature_table_map=feature_table_map,
        optimizer=OptimType.EXACT_ROWWISE_ADAGRAD
        if row_wise
        else OptimType.EXACT_ADAGRAD,
        learning_rate=lr,
        eps=eps,
        weights_precision=weights_precision,
        stochastic_rounding=stochastic_rounding,
        pooling_mode=pooling_mode,
    )

    if exact:
        del bs[table_to_replicate]
    for t in range(T):
        cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

    x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
    xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

    (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
    fc2 = (
        cc(indices, offsets)
        if not weighted
        else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
    )
    fc2.backward(torch.cat([go.view(B, -1) for go in gos], dim=1))
    cc.flush()

    split_optimizer_states = [s for (s,) in cc.split_optimizer_states()]
    for t in range(T):
        ref_optimizer_state = bs[t].weight.grad.float().to_dense().pow(2)
        torch.testing.assert_allclose(
            split_optimizer_states[t].float(),
            ref_optimizer_state.mean(dim=1) if row_wise else ref_optimizer_state,
            atol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-4,
            rtol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-4,
        )
    for t in range(T):
        # optimizer_state = squares (no row-wise) or sum squares (row-wise)
        torch.testing.assert_allclose(
            cc.split_embedding_weights()[t].float(),
            torch.addcdiv(
                bs[t].weight.float(),
                value=-lr,
                tensor1=bs[t].weight.grad.float().to_dense(),
                tensor2=split_optimizer_states[t]
                .float()
                .sqrt_()
                .add_(eps)
                .view(Es[t], 1 if row_wise else Ds[t]),
            ),
            atol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-4,
            rtol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-4,
        )
    if use_cpu:
        D_gradcheck = (D_gradcheck + 15) // 16 * 4
    else:
        D_gradcheck = D_gradcheck * 4

    cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(E, D_gradcheck, M, compute_device) for (E, M) in zip(Es, managed)],
        feature_table_map=feature_table_map,
        optimizer=OptimType.EXACT_ROWWISE_ADAGRAD
        if row_wise
        else OptimType.EXACT_ADAGRAD,
        learning_rate=0.0,
        eps=eps,
        weights_precision=weights_precision,
        stochastic_rounding=stochastic_rounding,
        # NOTE: only SUM pooling can work with per_sample_weights!
        pooling_mode=split_table_batched_embeddings_ops.PoolingMode.SUM,
    )
    if use_cpu:
        # NOTE: GPU version of SplitTableBatchedEmbeddingBagsCodegen doesn't support double.
        cc = cc.double()

    per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
    if use_cpu:
        per_sample_weights = per_sample_weights.double()
    per_sample_weights.requires_grad = True
    indices.requires_grad = False
    offsets.requires_grad = False
    for param in cc.parameters():
        param.requires_grad = False
    torch.autograd.gradcheck(cc, (indices, offsets, per_sample_weights))

    per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
    if use_cpu:
        per_sample_weights = per_sample_weights.double()
    per_sample_weights.requires_grad = True
    indices.requires_grad = False
    offsets.requires_grad = False
    for param in cc.parameters():
        param.requires_grad = False
    y = cc(indices, offsets, per_sample_weights)
    y.sum().backward()
    indice_weight_grad_all = per_sample_weights.grad.clone().cpu()
    T_ = len(xws)
    feature_requires_grad = to_device(
        torch.tensor(np.random.choice([0, 1], replace=True, size=(T_,))).int(),
        use_cpu,
    )
    per_sample_weights = per_sample_weights.detach().clone()
    per_sample_weights.requires_grad = True
    y = cc(
        indices,
        offsets,
        per_sample_weights,
        feature_requires_grad=feature_requires_grad,
    )
    y.sum().backward()
    indice_weight_grad_mask = per_sample_weights.grad.clone().cpu()
    for t in range(T_):
        if feature_requires_grad[t]:
            torch.testing.assert_allclose(
                indice_weight_grad_mask.view(T_, B, L)[t],
                indice_weight_grad_all.view(T_, B, L)[t],
            )
        else:
            torch.testing.assert_allclose(
                indice_weight_grad_mask.view(T_, B, L)[t],
                torch.zeros_like(indice_weight_grad_mask.view(T_, B, L)[t]),
            )
def test_backward_sgd(  # noqa C901
    self,
    T,
    D,
    B,
    log_E,
    L,
    weights_precision,
    weighted,
    mixed,
    use_cache,
    cache_algorithm,
    long_segments,
    pooling_mode,
    use_cpu,
):
    # NOTE: cache is not applicable to CPU version.
    assume(not use_cpu or not use_cache)
    # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
    assume(not use_cpu or T * B * L * D <= 2048)
    assume(not (use_cpu and weights_precision == SparseType.FP16))

    assume(
        pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
        or not weighted
    )
    mode = (
        "sum"
        if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
        else "mean"
    )

    E = int(10**log_E)
    if use_cpu:
        D = (D + 15) // 16 * 4
    else:
        D = D * 4
    if not mixed:
        Ds = [D] * T
        Es = [E] * T
    else:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        Es = [
            np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
        ]
    compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
    if use_cpu:
        managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
    elif use_cache:
        managed = [
            split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
        ] * T
        if mixed:
            average_D = sum(Ds) // T
            for t, d in enumerate(Ds):
                managed[t] = (
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                    if d < average_D
                    else managed[t]
                )
    else:
        managed = [
            np.random.choice(
                [
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                    split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                ]
            )
            for _ in range(T)
        ]
    bs = [
        to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
        for (E, D) in zip(Es, Ds)
    ]

    if weights_precision == SparseType.FP16 and not use_cpu:
        # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
        bs = [b.half() for b in bs]

    feature_table_map = list(range(T))
    table_to_replicate = T // 2
    bs.insert(table_to_replicate, bs[table_to_replicate])
    feature_table_map.insert(table_to_replicate, table_to_replicate)

    xs = [
        to_device(
            torch.from_numpy(
                np.random.choice(range(Es[t]), size=(B, L), replace=True).astype(
                    np.int64
                )
            ),
            use_cpu,
        )
        for t in feature_table_map
    ]

    if long_segments and L > 0:
        for x in xs:
            x[:, 0] = 0

    xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(len(xs))]
    xws_acc_type = copy.deepcopy(xws)

    if weights_precision == SparseType.FP16 and not use_cpu:
        # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
        xws = [xw.half() for xw in xws]

    fs = (
        [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
        if not weighted
        else [
            b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
            for (b, x, xw) in zip(bs, xs, xws)
        ]
    )
    gos = [torch.randn_like(f) for f in fs]
    [f.backward(go) for (f, go) in zip(fs, gos)]
    # do SGD update
    lr = 0.05
    del bs[table_to_replicate]
    new_weights = [(b.weight - b.weight.grad * lr) for b in bs]

    cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
        optimizer=OptimType.EXACT_SGD,
        feature_table_map=feature_table_map,
        learning_rate=0.05,
        weights_precision=weights_precision,
        cache_algorithm=cache_algorithm,
        pooling_mode=pooling_mode,
    )

    for t in range(T):
        cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

    x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
    xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

    (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
    fc2 = (
        cc(indices, offsets)
        if not weighted
        else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
    )
    goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
    fc2.backward(goc)
    if use_cache:
        cc.flush()
    for t in range(T):
        torch.testing.assert_allclose(
            cc.split_embedding_weights()[t],
            new_weights[t].half()
            if weights_precision == SparseType.FP16 and not use_cpu
            else new_weights[t],
            atol=(1.0e-2 if long_segments else 5.0e-3)
            if weights_precision == SparseType.FP16
            else 1.0e-5,
            rtol=2.0e-2 if weights_precision == SparseType.FP16 else 1.0e-5,
        )
def test_forward(
    self,
    T,
    D,
    B,
    log_E,
    L,
    weights_precision,
    weighted,
    mixed,
    use_cache,
    cache_algorithm,
    pooling_mode,
    use_cpu,
):
    # NOTE: cache is not applicable to CPU version.
    assume(not use_cpu or not use_cache)
    # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
    assume(not use_cpu or T * B * L * D <= 2048)
    assume(not (use_cpu and weights_precision == SparseType.FP16))

    assume(
        pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
        or not weighted
    )
    mode = (
        "sum"
        if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
        else "mean"
    )

    E = int(10**log_E)
    if use_cpu:
        D = (D + 15) // 16 * 4
    else:
        D = D * 4
    if not mixed:
        Ds = [D] * T
        Es = [int(1e4)] * T
    else:
        Ds = [
            div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        Es = [
            np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
        ]
    compute_device = split_table_batched_embeddings_ops.ComputeDevice.CUDA
    if use_cpu:
        managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
        compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
    elif use_cache:
        managed = [
            split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
        ] * T
        if mixed:
            average_D = sum(Ds) // T
            for t, d in enumerate(Ds):
                managed[t] = (
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
                    if d < average_D
                    else managed[t]
                )
    else:
        managed = [
            np.random.choice(
                [
                    split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                    split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
                ]
            )
            for _ in range(T)
        ]
    bs = [
        to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
        for (E, D) in zip(Es, Ds)
    ]
    if weights_precision == SparseType.FP16 and not use_cpu:
        # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
        bs = [b.half() for b in bs]

    xs = [to_device(torch.randint(low=0, high=e, size=(B, L)), use_cpu) for e in Es]
    xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)]
    xws_acc_type = copy.deepcopy(xws)

    if weights_precision == SparseType.FP16 and not use_cpu:
        # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
        xws = [xw.half() for xw in xws]

    fs = (
        [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
        if not weighted
        else [
            b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
            for (b, x, xw) in zip(bs, xs, xws)
        ]
    )
    f = torch.cat([f.view(B, -1) for f in fs], dim=1)

    cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[
            (
                E,
                D,
                split_table_batched_embeddings_ops.EmbeddingLocation(M),
                compute_device,
            )
            for (E, D, M) in zip(Es, Ds, managed)
        ],
        weights_precision=weights_precision,
        optimizer=OptimType.EXACT_SGD,
        learning_rate=0.05,
        cache_algorithm=cache_algorithm,
        pooling_mode=pooling_mode,
    )
    # NOTE: test TorchScript-compatible!
    cc = torch.jit.script(cc)

    for t in range(T):
        cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

    x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
    xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)

    (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu)
    fc2 = (
        cc(indices, offsets)
        if not weighted
        else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
    )
    torch.testing.assert_allclose(
        fc2.float(),
        f.float(),
        atol=8.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5,
        rtol=8.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5,
    )
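# A minimal end-to-end sketch (assumptions: a CUDA device is available and the
# enums/ops are the same ones used throughout this file) showing the
# construction and forward pattern the tests above exercise, without
# hypothesis-driven parameters. Sizes below are hypothetical.
def example_single_table_forward():
    E, D, B, L = 1000, 16, 4, 2  # rows, dim, batch size, bag size
    cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[
            (
                E,
                D,
                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
            )
        ],
        optimizer=OptimType.EXACT_SGD,
        learning_rate=0.05,
    )
    indices = torch.randint(low=0, high=E, size=(B * L,)).long().cuda()
    offsets = torch.arange(0, B * L + 1, L).long().cuda()  # B bags of L lookups each
    return cc(indices, offsets)  # pooled output of shape (B, D)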