# Assumed imports for this excerpt. The TT-embedding ops (TTEmbeddingBag,
# TableBatchedTTEmbeddingBag, OptimType, tt_matrix_to_full) and the helpers
# (generate_sparse_feature, generate_requests, benchmark_requests) are
# expected to come from the surrounding package; their exact module paths
# are not shown here.
import logging

import numpy as np
import torch


def test_backward_dense(self, batch_size, pooling_factor, pooling_factor_std, tt_ndims):
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    tt_p_shapes = [7, 9, 11, 5]
    tt_q_shapes = [3, 4, 5, 7]
    tt_ranks = [13, 12, 7]
    tt_p_shapes = tt_p_shapes[:tt_ndims]
    tt_q_shapes = tt_q_shapes[:tt_ndims]
    tt_ranks = tt_ranks[: (tt_ndims - 1)]
    num_embeddings = np.prod(np.array(tt_p_shapes))
    embedding_dim = np.prod(np.array(tt_q_shapes))
    _, indices, offsets, _ = generate_sparse_feature(
        batch_size,
        num_embeddings=num_embeddings,
        pooling_factor=float(pooling_factor),
        pooling_factor_std=float(pooling_factor_std),
        generate_scores=False,
        unary=False,
        unique=False,
    )
    # create TT-Embedding op
    offsets = torch.tensor(offsets, dtype=torch.int64, device=device)
    indices = torch.tensor(indices, dtype=torch.int64, device=device)
    tt_emb = TTEmbeddingBag(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        tt_p_shapes=tt_p_shapes,
        tt_q_shapes=tt_q_shapes,
        tt_ranks=tt_ranks,
        sparse=False,
        weight_dist="uniform",
    )
    tt_emb.to(device)
    emb = torch.nn.EmbeddingBag(
        num_embeddings,
        embedding_dim,
        sparse=True,
        mode="sum",
        _weight=tt_emb.full_weight(),
        include_last_offset=True,
    )
    emb.to(device)
    d_output = torch.rand(batch_size, embedding_dim, device=device) * 0.1
    tt_cores = [tt.clone().detach().requires_grad_(True) for tt in tt_emb.tt_cores]
    full_weight = tt_matrix_to_full(
        tt_p_shapes, tt_q_shapes, tt_ranks, tt_cores, [1, 0, 2, 3]
    )
    # tt_emb
    output = tt_emb(indices, offsets)
    output.backward(d_output)
    # reference: backprop the dense EmbeddingBag gradient through the TT
    # reconstruction to obtain per-core reference gradients
    output_ref = emb(indices.long(), offsets.long())
    output_ref.backward(d_output)
    d_weight_ref = emb.weight.grad.to_dense()
    full_weight.backward(d_weight_ref)
    for i in range(tt_ndims):
        torch.testing.assert_allclose(tt_emb.tt_cores[i].grad, tt_cores[i].grad)
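# A minimal sketch of the reconstruction that tt_matrix_to_full / full_weight()
# perform conceptually: contracting TT-matrix cores of assumed shape
# (r_{k-1}, p_k, q_k, r_k), with r_0 = r_d = 1, into a dense
# (prod(p_shapes), prod(q_shapes)) weight matrix. The core layout and the
# permutation used by the actual library may differ; the name and the example
# below are illustrative only.
def tt_matrix_to_full_sketch(cores):
    res = cores[0]  # (1, p_1, q_1, r_1)
    for core in cores[1:]:
        # contract the trailing rank of the partial result with the leading
        # rank of the next core, then merge the row and column mode indices
        res = torch.einsum("apqr,rstb->apsqtb", res, core)
        a, p, s, q, t, b = res.shape
        res = res.reshape(a, p * s, q * t, b)
    return res.reshape(res.shape[1], res.shape[2])


# Example with the 3-dim shapes used in the tests:
# p = [7, 9, 11], q = [3, 4, 5], ranks = [13, 12] -> a (693, 60) dense weight.
_cores = [
    torch.randn(1, 7, 3, 13),
    torch.randn(13, 9, 4, 12),
    torch.randn(12, 11, 5, 1),
]
assert tt_matrix_to_full_sketch(_cores).shape == (7 * 9 * 11, 3 * 4 * 5)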
def test_forward(self, batch_size, pooling_factor, pooling_factor_std, tt_ndims):
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    tt_p_shapes = [7, 9, 11, 5]
    tt_q_shapes = [3, 4, 5, 7]
    tt_ranks = [13, 12, 7]
    tt_p_shapes = tt_p_shapes[:tt_ndims]
    tt_q_shapes = tt_q_shapes[:tt_ndims]
    tt_ranks = tt_ranks[: (tt_ndims - 1)]
    num_embeddings = np.prod(np.array(tt_p_shapes))
    embedding_dim = np.prod(np.array(tt_q_shapes))
    _, indices, offsets, _ = generate_sparse_feature(
        batch_size,
        num_embeddings=num_embeddings,
        pooling_factor=float(pooling_factor),
        pooling_factor_std=float(pooling_factor_std),
        generate_scores=False,
        unary=False,
        unique=False,
    )
    # create TT-Embedding op
    offsets = torch.tensor(offsets, dtype=torch.int64, device=device)
    indices = torch.tensor(indices, dtype=torch.int64, device=device)
    tt_emb = TTEmbeddingBag(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        tt_p_shapes=tt_p_shapes,
        tt_q_shapes=tt_q_shapes,
        tt_ranks=tt_ranks,
        sparse=False,
        weight_dist="uniform",
    )
    tt_emb.to(device)
    emb = torch.nn.EmbeddingBag(
        num_embeddings,
        embedding_dim,
        sparse=True,
        mode="sum",
        _weight=tt_emb.full_weight(),
        include_last_offset=True,
    )
    emb.to(device)
    # forward
    output = tt_emb(indices, offsets)
    output_ref = emb(indices.long(), offsets.long())
    torch.testing.assert_allclose(output, output_ref)
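# For reference, the (indices, offsets) encoding that both bags consume:
# with include_last_offset=True, offsets has batch_size + 1 entries and bag b
# pools (mode="sum") the rows indices[offsets[b]:offsets[b + 1]]. A tiny
# CPU-only illustration with arbitrary values:
_emb = torch.nn.EmbeddingBag(10, 4, mode="sum", include_last_offset=True)
_indices = torch.tensor([1, 3, 3, 7, 2])   # 5 lookups in total
_offsets = torch.tensor([0, 2, 2, 5])      # bags: [1, 3], [] (all zeros), [3, 7, 2]
_pooled = _emb(_indices, _offsets)
assert _pooled.shape == (3, 4)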
def test_backward_adagrad(self, batch_size, pooling_factor, pooling_factor_std, tt_ndims):
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    tt_p_shapes = [7, 9, 11, 5]
    tt_q_shapes = [3, 4, 5, 7]
    tt_ranks = [13, 12, 7]
    tt_p_shapes = tt_p_shapes[:tt_ndims]
    tt_q_shapes = tt_q_shapes[:tt_ndims]
    tt_ranks = tt_ranks[: (tt_ndims - 1)]
    num_embeddings = np.prod(np.array(tt_p_shapes))
    embedding_dim = np.prod(np.array(tt_q_shapes))
    learning_rate = 0.1
    eps = 0.0001
    _, indices, offsets, _ = generate_sparse_feature(
        batch_size,
        num_embeddings=num_embeddings,
        pooling_factor=float(pooling_factor),
        pooling_factor_std=float(pooling_factor_std),
        generate_scores=False,
        unary=False,
        unique=False,
    )
    # create TT-Embedding op
    offsets = torch.tensor(offsets, dtype=torch.int64, device=device)
    indices = torch.tensor(indices, dtype=torch.int64, device=device)
    tt_emb = TTEmbeddingBag(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        tt_p_shapes=tt_p_shapes,
        tt_q_shapes=tt_q_shapes,
        tt_ranks=tt_ranks,
        sparse=True,
        optimizer=OptimType.EXACT_ADAGRAD,
        learning_rate=learning_rate,
        eps=eps,
    )
    tt_emb.to(device)
    emb = torch.nn.EmbeddingBag(
        num_embeddings,
        embedding_dim,
        sparse=True,
        mode="sum",
        _weight=tt_emb.full_weight(),
        include_last_offset=True,
    )
    emb.to(device)
    d_output = torch.rand(batch_size, embedding_dim, device=device) * 0.1
    tt_cores = [tt.clone().detach().requires_grad_(True) for tt in tt_emb.tt_cores]
    full_weight = tt_matrix_to_full(
        tt_p_shapes, tt_q_shapes, tt_ranks, tt_cores, [1, 0, 2, 3]
    )
    # tt_emb: the op applies its Adagrad update as part of backward
    output = tt_emb(indices, offsets)
    output.backward(d_output)
    # reference
    output_ref = emb(indices.long(), offsets.long())
    output_ref.backward(d_output)
    d_weight_ref = emb.weight.grad.to_dense()
    full_weight.backward(d_weight_ref)
    # with a zero-initialized accumulator, one Adagrad step leaves the
    # optimizer state equal to grad**2
    new_optimizer_state = [torch.mul(t.grad, t.grad) for t in tt_cores]
    new_tt_cores = [
        t - torch.div(t.grad * learning_rate, torch.sqrt(new_optimizer_state[i]) + eps)
        for i, t in enumerate(tt_cores)
    ]
    for i in range(tt_ndims):
        torch.testing.assert_allclose(
            tt_emb.optimizer_state[i], new_optimizer_state[i]
        )
        torch.testing.assert_allclose(tt_emb.tt_cores[i], new_tt_cores[i])
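# The reference update checked above, written out as a standalone single step
# of exact (dense) Adagrad. The function name is illustrative; it assumes the
# same zero-initialized accumulator as the test, so after one step the state
# is simply grad**2.
def adagrad_step(weight, grad, state, learning_rate, eps):
    # accumulate squared gradients, then scale the step per element
    state = state + grad * grad
    weight = weight - learning_rate * grad / (torch.sqrt(state) + eps)
    return weight, state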
def test_backward_table_batched(
    self, batch_size, pooling_factor, pooling_factor_std, tt_ndims, num_tables
):
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    tt_p_shapes = [7, 9, 11, 5]
    tt_q_shapes = [3, 4, 5, 7]
    tt_ranks = [13, 12, 7]
    tt_p_shapes = tt_p_shapes[:tt_ndims]
    tt_q_shapes = tt_q_shapes[:tt_ndims]
    tt_ranks = tt_ranks[: (tt_ndims - 1)]
    num_embeddings = np.prod(np.array(tt_p_shapes))
    embedding_dim = np.prod(np.array(tt_q_shapes))
    # create table batched tt embedding bag
    batched_tt_emb = TableBatchedTTEmbeddingBag(
        num_tables=num_tables,
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        tt_p_shapes=tt_p_shapes,
        tt_q_shapes=tt_q_shapes,
        tt_ranks=tt_ranks,
        sparse=False,
        weight_dist="uniform",
        use_cache=False,
    )
    batched_tt_emb.to(device)
    tt_embs = []
    lengths_per_table = []
    indices_per_table = []
    inputs_per_table = []
    for i in range(num_tables):
        lengths, indices, offsets, _ = generate_sparse_feature(
            batch_size,
            num_embeddings=num_embeddings,
            pooling_factor=float(pooling_factor),
            pooling_factor_std=float(pooling_factor_std),
            generate_scores=False,
            unary=False,
            unique=False,
        )
        lengths_per_table.extend(lengths)
        indices_per_table.extend(indices)
        offsets = torch.tensor(offsets, dtype=torch.int64, device=device)
        indices = torch.tensor(indices, dtype=torch.int64, device=device)
        inputs_per_table.append((indices, offsets))
        # create TT-Embedding op
        tt_emb = TTEmbeddingBag(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            tt_p_shapes=tt_p_shapes,
            tt_q_shapes=tt_q_shapes,
            tt_ranks=tt_ranks,
            sparse=False,
            weight_dist="uniform",
            use_cache=False,
        )
        tt_emb.to(device)
        tt_embs.append(tt_emb)
        # copy tt cores to table batched
        for j, tt_core in enumerate(batched_tt_emb.tt_cores):
            tt_core.detach()[i].copy_(tt_emb.tt_cores[j][0].detach())
    batched_offsets = torch.tensor(
        [0] + list(np.cumsum(lengths_per_table)), dtype=torch.int64, device=device
    )
    batched_indices = torch.tensor(
        indices_per_table, dtype=torch.int64, device=device
    )
    batched_output = batched_tt_emb(batched_indices, batched_offsets)
    assert batched_offsets.numel() - 1 == batch_size * num_tables
    outputs = [
        tt_embs[i](indices, offsets)
        for i, (indices, offsets) in enumerate(inputs_per_table)
    ]
    d_batched_output = (
        torch.rand(num_tables, batch_size, embedding_dim, device=device) * 0.1
    )
    batched_output.backward(d_batched_output)
    for i, output in enumerate(outputs):
        output.backward(d_batched_output[i])
        for j, tt_core in enumerate(tt_embs[i].tt_cores):
            torch.testing.assert_allclose(
                tt_core.grad[0], batched_tt_emb.tt_cores[j].grad[i]
            )
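# How the per-table (lengths, indices) are flattened into a single
# (batched_indices, batched_offsets) pair above: bag lengths from all tables
# are concatenated in table order and turned into offsets by a cumulative sum
# with a leading 0, giving num_tables * batch_size + 1 offsets. A small
# concrete example with 2 hypothetical tables and batch_size = 2:
_lengths_per_table = [2, 1,   # table 0: bags of 2 and 1 lookups
                      0, 3]   # table 1: bags of 0 and 3 lookups
_batched_offsets = torch.tensor([0] + list(np.cumsum(_lengths_per_table)))
assert _batched_offsets.tolist() == [0, 2, 3, 3, 6]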
def main(
    batch_size,
    iters,
    long_index,
    pooling_factor,
    p_shapes,
    q_shapes,
    ranks,
    sparse,
    optimizer,
    run_baseline,
):
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    num_embeddings = np.prod(np.array(p_shapes))
    embedding_dim = np.prod(np.array(q_shapes))
    requests = generate_requests(
        iters, batch_size, 1, pooling_factor, num_embeddings, long_index
    )
    nnz = batch_size * pooling_factor
    # FLOP estimate for a 3-core TT lookup: two sequential core contractions
    # per looked-up row, counted as multiply-adds (hence the factor of 2.0)
    flop = (
        q_shapes[0] * ranks[0] * q_shapes[1] * ranks[1]
        + q_shapes[0] * q_shapes[1] * ranks[1] * q_shapes[2]
    )
    flop = 2.0 * nnz * flop * iters
    # bytes of pooled fp32 output across all iterations
    bw = 4.0 * nnz * embedding_dim * iters
    # create TT-Embedding op
    if optimizer == "sgd":
        optimizer = OptimType.SGD
    else:
        optimizer = OptimType.EXACT_ADAGRAD
    tt_emb = TTEmbeddingBag(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        tt_p_shapes=p_shapes,
        tt_q_shapes=q_shapes,
        tt_ranks=ranks,
        sparse=sparse,
        optimizer=optimizer,
        use_cache=True,
    )
    tt_emb.to(device)
    logging.info(f"sparse: {sparse}, optimizer: {optimizer}")
    logging.info(f"p_shapes: {p_shapes}, q_shapes: {q_shapes}, ranks: {ranks}")
    logging.info(
        f"B: {batch_size}, E: {num_embeddings}, "
        f"D: {embedding_dim}, nnz: {nnz}"
    )
    grad_output = torch.rand(batch_size, embedding_dim, device=device) * 0.1
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, _: tt_emb(indices, offsets).backward(grad_output),
    )
    logging.info(
        f"TTEmbeddingBag FWD-BWD time/nnz: {time_per_iter / nnz * 1e6: .3f} usecs, "
        f"GFLOPS: {3.0 * flop / time_per_iter / 1e9: .3f}, "
        f"BW: {3.0 * bw / time_per_iter / 1e9: .3f}"
    )
    # EmbeddingBag baseline
    if run_baseline:
        emb = torch.nn.EmbeddingBag(
            num_embeddings,
            embedding_dim,
            sparse=True,
            mode="sum",
            include_last_offset=True,
        )
        emb.to(device)
        time_per_iter = benchmark_requests(
            requests,
            lambda indices, offsets, _: emb(indices, offsets).backward(grad_output),
        )
        logging.info(
            f"EmbeddingBag FWD-BWD time/nnz: {time_per_iter / nnz * 1e6: .3f} usecs, "
            f"BW: {3.0 * bw / time_per_iter / 1e9: .3f}"
        )
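# A worked example of the estimates above, with hypothetical benchmark
# parameters (not taken from the source): q_shapes = [4, 4, 4],
# ranks = [32, 32], batch_size = 512, pooling_factor = 20, so
# embedding_dim = 64 and nnz = 10240 per iteration.
_q, _r = [4, 4, 4], [32, 32]
_nnz = 512 * 20
_macs_per_row = _q[0] * _r[0] * _q[1] * _r[1] + _q[0] * _q[1] * _r[1] * _q[2]
assert _macs_per_row == 18432                 # 16384 + 2048 multiply-adds
_flops_per_iter = 2.0 * _nnz * _macs_per_row  # ~0.38 GFLOP of forward work
_bytes_per_iter = 4.0 * _nnz * 64             # ~2.6 MB of fp32 pooled output
# The 3.0 factor in the logged GFLOPS/BW numbers presumably folds the backward
# pass into the forward-only estimate (the common fwd + 2x bwd approximation).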