def test_spec_aug_kernel_mask_value(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    cfg = get_cfg(seed=0)
    cfg.freq_masks = 2
    cfg.time_masks = 10
    cfg.mask_value = -1.0

    data = prepare_data(**cfg)
    launch_kernel(data, cfg)

    x, x_len, sh = data['x'], data['x_len'], data['sh']

    # Assert freq masks are correct
    for bidx in range(sh[0]):
        for f_start, f_len in zip(data['freq_starts'][bidx], data['freq_lengths'][bidx]):
            freq_mask_check(x, x_len, f_start, f_len, mask_value=cfg.mask_value, bidx=bidx)

    # Assert time masks are correct
    for bidx in range(sh[0]):
        for t_start, t_len in zip(data['time_starts'][bidx], data['time_lengths'][bidx]):
            time_mask_check(x, x_len, t_start, t_len, mask_value=cfg.mask_value, bidx=bidx)
def test_compute_costs_data(self, batch_size, fastemit_lambda):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    x = np.full([batch_size], fill_value=0.0)
    y = np.random.randn(batch_size)

    stream = cuda.stream()
    x_c = cuda.to_device(x, stream=stream)
    y_c = cuda.to_device(y, stream=stream)

    # call kernel
    threads_per_block = min(x.shape[0], 32)
    blocks_per_grid = (x.shape[0] + (threads_per_block - 1)) // threads_per_block
    # Kernel call (source, dest, extra_args_...)
    rnnt_helper.compute_costs_data[blocks_per_grid, threads_per_block, stream](y_c, x_c, fastemit_lambda)

    # sync kernel
    stream.synchronize()
    x_new = x_c.copy_to_host(stream=stream)
    del x_c, y_c

    res = -(y.copy())
    res *= 1.0 + fastemit_lambda

    for i in range(len(x_new)):
        assert x_new[i] == res[i], f"index failed {i}"
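# The host-side reference above assumes compute_costs_data negates the per-sample
# log-likelihoods (turning them into costs) and scales each cost by (1 + fastemit_lambda):
#   cost[i] = -(ll[i]) * (1.0 + fastemit_lambda)
# which is exactly what `res` recomputes element-wise for the comparison.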
def test_exponential(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    # wrapper kernel for the device function under test
    @cuda.jit
    def _kernel(x):
        x_pos = cuda.grid(1)
        if x_pos < x.shape[0]:
            x[x_pos] = rnnt_helper.exponential(x[x_pos])

    x = np.random.rand(8)

    stream = cuda.stream()
    x_c = cuda.to_device(x, stream=stream)

    # call kernel
    threads_per_block = global_constants.threads_per_block()
    blocks_per_grid = (x.shape[0] + threads_per_block - 1) // threads_per_block
    _kernel[blocks_per_grid, threads_per_block, stream](x_c)

    # sync kernel
    stream.synchronize()
    x_new = x_c.copy_to_host(stream=stream)
    del x_c

    y = np.exp(x)
    for i in range(len(x_new)):
        assert abs(x_new[i] - y[i]) < 1e-4
def test_log_sum_exp_neg_inf(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    # wrapper kernel for the device function under test
    @cuda.jit
    def _kernel(x, y):
        x_pos = cuda.grid(1)
        if x_pos < x.shape[0] and x_pos < y.shape[0]:
            x[x_pos] = rnnt_helper.log_sum_exp(x[x_pos], y[x_pos])

    x = np.asarray([global_constants.FP32_NEG_INF] * 8)
    y = np.ones([len(x)])

    stream = cuda.stream()
    x_c = cuda.to_device(x, stream=stream)
    y_c = cuda.to_device(y, stream=stream)

    # call kernel
    threads_per_block = global_constants.threads_per_block()
    blocks_per_grid = (x.shape[0] + threads_per_block - 1) // threads_per_block
    _kernel[blocks_per_grid, threads_per_block, stream](x_c, y_c)

    # sync kernel
    stream.synchronize()
    x_new = x_c.copy_to_host(stream=stream)
    del x_c, y_c

    # log_sum_exp(-inf, 1.0) should collapse to 1.0
    assert np.allclose(x_new, np.ones_like(x_new), atol=1e-5)
def test_log_plus(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    # wrapper kernel for the device function under test
    @cuda.jit
    def _kernel(x, y):
        x_pos = cuda.grid(1)
        if x_pos < x.shape[0] and x_pos < y.shape[0]:
            x[x_pos] = rnnt_helper.log_plus(x[x_pos], y[x_pos])

    x = np.full([8], fill_value=10.0)
    y = np.full([8], fill_value=2.0)

    stream = cuda.stream()
    x_c = cuda.to_device(x, stream=stream)
    y_c = cuda.to_device(y, stream=stream)

    # call kernel
    threads_per_block = global_constants.threads_per_block()
    blocks_per_grid = (x.shape[0] + threads_per_block - 1) // threads_per_block
    _kernel[blocks_per_grid, threads_per_block, stream](x_c, y_c)

    # sync kernel
    stream.synchronize()
    x_new = x_c.copy_to_host(stream=stream)
    del x_c, y_c

    # numerically stable log(exp(x) + exp(y)) reference
    z = np.log1p(np.exp(-np.fabs(x - y))) + np.maximum(x, y)
    for i in range(len(x_new)):
        assert x_new[i] == z[i]
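# The NumPy reference above uses the standard numerically stable identity
#   log(exp(a) + exp(b)) = max(a, b) + log1p(exp(-|a - b|)),
# which avoids overflow when a or b is large. For example, with a = 10 and b = 2:
#   max(10, 2) + log1p(exp(-8)) ≈ 10.000335, so every element of x_new should equal that value.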
def test_log_sum_exp(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    # wrapper kernel for the device function under test
    @cuda.jit
    def _kernel(x, y):
        x_pos = cuda.grid(1)
        if x_pos < x.shape[0] and x_pos < y.shape[0]:
            x[x_pos] = rnnt_helper.log_sum_exp(x[x_pos], y[x_pos])

    x = np.zeros([8])
    y = np.ones([8])

    stream = cuda.stream()
    x_c = cuda.to_device(x, stream=stream)
    y_c = cuda.to_device(y, stream=stream)

    # call kernel
    threads_per_block = global_constants.threads_per_block()
    blocks_per_grid = (x.shape[0] + threads_per_block - 1) // threads_per_block
    _kernel[blocks_per_grid, threads_per_block, stream](x_c, y_c)

    # sync kernel
    stream.synchronize()
    x_new = x_c.copy_to_host(stream=stream)
    del x_c, y_c

    assert abs(x_new.sum() - 10.506093500145782) <= 1e-5
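# The expected constant above follows from log_sum_exp(0, 1) = log(exp(0) + exp(1))
# = 1 + log1p(exp(-1)) ≈ 1.3132617; eight identical elements then sum to
# 8 * 1.3132617 ≈ 10.5060935, the value checked in the assert.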
def test_case_small_fastemit_clamp(self, device, fastemit_lambda):
    if device == 'cuda':
        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    GRAD_CLAMP = 0.1
    acts = np.array(
        [
            [
                [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]],
                [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
            ]
        ]
    )
    labels = [[1, 2]]

    fn_pt = RNNTLossNumba(blank=0, reduction='sum', fastemit_lambda=fastemit_lambda, clamp=GRAD_CLAMP)
    pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)

    fn_np = RNNTLoss_Numpy(blank=0, fastemit_lambda=fastemit_lambda, clamp=GRAD_CLAMP)
    np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device)

    expected_cost = 4.495666
    expected_cost += expected_cost * fastemit_lambda

    assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "small_test costs mismatch."
    assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_test costs mismatch."
    assert np.allclose(pt_grads, np_grads), "small_test gradient mismatch."
def test_reduce_exp(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    random = np.random.RandomState(0)
    original_shape = [1, 5, 4, 2]
    x = random.randn(*original_shape).reshape([-1])
    dx = np.zeros_like(x)

    stream = cuda.stream()
    x_c = cuda.to_device(x, stream=stream)
    dx_c = cuda.to_device(dx, stream=stream)

    # call kernel
    cols = np.prod(original_shape[:3])
    reduce.reduce_exp(x_c, dx_c, rows=original_shape[-1], cols=cols, minus=False, stream=stream)

    # sync kernel
    stream.synchronize()
    dx_result = dx_c.copy_to_host(stream=stream)
    del x_c, dx_c

    # collect results in first [B * T * U] values; for all V
    assert (dx_result[cols:] - dx[cols:]).sum() <= 1e-7

    # make sure dx_result updates the [B * T * U] values
    assert np.abs(dx_result[:cols] - dx[:cols]).sum() > 0
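# Layout assumed by the reduction above: the flattened [B, T, U, V] tensor is treated as
# B * T * U independent columns of length V (rows=V, cols=B*T*U), and reduce_exp writes one
# reduced value per column into dx. That is why only the first `cols` entries of dx_result
# are expected to change while the tail stays untouched.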
def test_SpectrogramAugmentation_numba_kernel(self, caplog):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    logging._logger.propagate = True
    original_verbosity = logging.get_verbosity()
    logging.set_verbosity(logging.DEBUG)
    caplog.set_level(logging.DEBUG)

    # Make sure constructor works
    instance1 = modules.SpectrogramAugmentation(
        freq_masks=10, time_masks=3, rect_masks=3, use_numba_spec_augment=True
    )
    assert isinstance(instance1, modules.SpectrogramAugmentation)

    # Make sure forward doesn't throw with expected input
    instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0)
    input_signal = torch.randn(size=(8, 512))
    length = torch.randint(low=161, high=500, size=[8])
    res0 = instance0(input_signal=input_signal, length=length)
    res = instance1(input_spec=res0[0], length=length)

    assert res.shape == res0[0].shape

    # check that the numba kernel debug message indicates that it is available for use
    assert "Numba SpecAugment kernel is available" in caplog.text

    logging._logger.propagate = False
    logging.set_verbosity(original_verbosity)
def test_case_small_random_fastemit_reg(self, device, fastemit_lambda):
    if device == 'cuda':
        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    rng = np.random.RandomState(0)
    acts = rng.randn(1, 4, 3, 3)
    labels = [[1, 2]]

    fn_pt = RNNTLossNumba(blank=0, reduction='sum', fastemit_lambda=fastemit_lambda)
    pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)

    fn_np = RNNTLoss_Numpy(fastemit_lambda=fastemit_lambda)
    np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device)

    assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_random_test costs mismatch."
    assert np.allclose(pt_grads, np_grads, atol=1e-5, rtol=1e-5), "small_random_test gradient mismatch."
def test_spec_aug_kernel_no_freq_time_mask(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    cfg = get_cfg(seed=0)
    cfg.freq_masks = 0
    cfg.time_masks = 0

    data = prepare_data(**cfg)
    x, x_len, sh = data['x'], data['x_len'], data['sh']
    x_copy = x.clone()

    launch_kernel(data, cfg)

    # Assert no data edits occurred
    assert (data['x'] - x_copy).abs().mean() <= 1e-9
def test_spec_aug_kernel_grad(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    cfg = get_cfg(seed=0)
    cfg.freq_masks = 2
    cfg.time_masks = 10

    data = prepare_data(**cfg)
    launch_kernel(data, cfg)

    result = data['x']  # inplace modification via kernel
    y = torch.ones_like(result, requires_grad=True)
    z = y + result
    z.mean().backward()

    assert y.grad is not None
def test_case_small(self, device):
    if device == 'cuda':
        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    acts = np.array(
        [
            [
                [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]],
                [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
            ]
        ]
    )
    labels = [[1, 2]]

    fn_pt = RNNTLossNumba(blank=0, reduction='sum')
    pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)

    fn_np = RNNTLoss_Numpy()
    np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device)

    expected_cost = 4.495666
    expected_grads = np.array(
        [
            [
                [
                    [-0.13116688, -0.3999269, 0.17703125, 0.17703125, 0.17703125],
                    [-0.18572757, 0.12247056, -0.18168412, 0.12247056, 0.12247056],
                    [-0.32091254, 0.06269141, 0.06928472, 0.12624499, 0.06269141],
                ],
                [
                    [0.05456069, -0.21824276, 0.05456069, 0.05456069, 0.05456069],
                    [0.12073959, 0.12073959, -0.48295835, 0.12073959, 0.12073959],
                    [-0.6925882, 0.16871116, 0.18645467, 0.16871116, 0.16871116],
                ],
            ]
        ]
    )

    assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "small_test costs mismatch."
    assert np.allclose(pt_grads, expected_grads), "small_test gradient mismatch."
    assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_test costs mismatch."
    assert np.allclose(pt_grads, np_grads), "small_test gradient mismatch."
def test_spec_aug_kernel_large_batch(self, dtype):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    # Change max threads per block temporarily
    original_buffer = spec_aug_numba.MAX_THREAD_BUFFER
    spec_aug_numba.MAX_THREAD_BUFFER = 4

    cfg = get_cfg(seed=0, dtype=dtype)
    cfg.freq_masks = 2
    cfg.time_masks = 10
    cfg.b = spec_aug_numba.MAX_THREAD_BUFFER + 1

    data = prepare_data(**cfg)
    launch_kernel(data, cfg)

    x, x_len, sh = data['x'], data['x_len'], data['sh']

    # Assert freq masks are correct
    for bidx in range(sh[0]):
        for f_start, f_len in zip(data['freq_starts'][bidx], data['freq_lengths'][bidx]):
            freq_mask_check(x, x_len, f_start, f_len, mask_value=cfg.mask_value, bidx=bidx)

    # Assert time masks are correct
    for bidx in range(sh[0]):
        for t_start, t_len in zip(data['time_starts'][bidx], data['time_lengths'][bidx]):
            time_mask_check(x, x_len, t_start, t_len, mask_value=cfg.mask_value, bidx=bidx)

    spec_aug_numba.MAX_THREAD_BUFFER = original_buffer
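# cfg.b is deliberately set to MAX_THREAD_BUFFER + 1 (with the buffer shrunk to 4 above),
# presumably so the batch no longer fits into a single thread buffer and the kernel's
# handling of oversized batches is exercised; the mask checks then verify the result
# matches the behaviour of the normal-sized path.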
def test_case_large_random(self, device):
    if device == 'cuda':
        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    rng = np.random.RandomState(0)
    acts = rng.randn(4, 8, 11, 5)
    labels = [
        [1, 2, 4, 3, 2, 2, 1, 1, 1, 1],
        [3, 2, 2, 3, 4, 1, 1, 1, 1, 1],
        [4, 4, 1, 2, 1, 3, 4, 3, 1, 2],
        [1, 1, 2, 1, 2, 3, 3, 1, 1, 1],
    ]

    fn_pt = RNNTLossNumba(blank=0, reduction='sum')
    pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)

    fn_np = RNNTLoss_Numpy()
    np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device)

    assert np.allclose(pt_cost, np_cost, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch."
    assert np.allclose(pt_grads, np_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch."
def test_case_big_tensor(self, device):
    if device == 'cuda':
        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    # minibatch x T x U x alphabet_size
    activations = [
        [
            [
                [0.06535690384862791, 0.7875301411923206, 0.08159176605666074],
                [0.5297155426466327, 0.7506749639230854, 0.7541348379087998],
                [0.6097641124736383, 0.8681404965673826, 0.6225318186056529],
            ],
            [
                [0.6685222872103057, 0.8580392805336061, 0.16453892311765583],
                [0.989779515236694, 0.944298460961015, 0.6031678586829663],
                [0.9467833543605416, 0.666202507295747, 0.28688179752461884],
            ],
            [
                [0.09418426230195986, 0.3666735970751962, 0.736168049462793],
                [0.1666804425271342, 0.7141542198635192, 0.3993997272216727],
                [0.5359823524146038, 0.29182076440286386, 0.6126422611507932],
            ],
            [
                [0.3242405528768486, 0.8007644367291621, 0.5241057606558068],
                [0.779194617063042, 0.18331417220174862, 0.113745182072432],
                [0.24022162381327106, 0.3394695622533106, 0.1341595066017014],
            ],
        ],
        [
            [
                [0.5055615569388828, 0.051597282072282646, 0.6402903936686337],
                [0.43073311517251, 0.8294731834714112, 0.1774668847323424],
                [0.3207001991262245, 0.04288308912457006, 0.30280282975568984],
            ],
            [
                [0.6751777088333762, 0.569537369330242, 0.5584738347504452],
                [0.08313242153985256, 0.06016544344162322, 0.10795752845152584],
                [0.7486153608562472, 0.943918041459349, 0.4863558118797222],
            ],
            [
                [0.4181986264486809, 0.6524078485043804, 0.024242983423721887],
                [0.13458171554507403, 0.3663418070512402, 0.2958297395361563],
                [0.9236695822497084, 0.6899291482654177, 0.7418981733448822],
            ],
            [
                [0.25000547599982104, 0.6034295486281007, 0.9872887878887768],
                [0.5926057265215715, 0.8846724004467684, 0.5434495396894328],
                [0.6607698886038497, 0.3771277082495921, 0.3580209022231813],
            ],
        ],
    ]

    expected_costs = [4.2806528590890736, 3.9384369822503591]
    expected_grads = [
        [
            [
                [-1.86843902e-01, -6.25548810e-02, 2.49398798e-01],
                [-2.03376666e-01, 2.02399328e-01, 9.77333169e-04],
                [-1.41016081e-01, 7.91234672e-02, 6.18926100e-02],
            ],
            [
                [-1.15517676e-02, -8.12802389e-02, 9.28319991e-02],
                [-1.54257029e-01, 2.29432687e-01, -7.51756504e-02],
                [-2.46593088e-01, 1.46404594e-01, 1.00188486e-01],
            ],
            [
                [-1.29182907e-02, -6.15932420e-02, 7.45115355e-02],
                [-5.59857301e-02, 2.19830811e-01, -1.63845062e-01],
                [-4.97626871e-01, 2.09239945e-01, 2.88386941e-01],
            ],
            [
                [1.36048580e-02, -3.02196294e-02, 1.66147724e-02],
                [1.13924511e-01, 6.27811998e-02, -1.76705718e-01],
                [-6.67078257e-01, 3.67658824e-01, 2.99419403e-01],
            ],
        ],
        [
            [
                [-3.56343776e-01, -5.53474613e-02, 4.11691219e-01],
                [-9.69219357e-02, 2.94591039e-02, 6.74628317e-02],
                [-6.35175705e-02, 2.76544970e-02, 3.58630717e-02],
            ],
            [
                [-1.54499024e-01, -7.39420280e-02, 2.28441030e-01],
                [-1.66789949e-01, -8.78955179e-05, 1.66877866e-01],
                [-1.72369644e-01, 1.05565332e-01, 6.68043196e-02],
            ],
            [
                [2.38748826e-02, -1.18255816e-01, 9.43809375e-02],
                [-1.04707085e-01, -1.08934477e-01, 2.13641584e-01],
                [-3.69844258e-01, 1.80118099e-01, 1.89726159e-01],
            ],
            [
                [2.57137045e-02, -7.94617534e-02, 5.37480488e-02],
                [1.22328237e-01, -2.38788679e-01, 1.16460443e-01],
                [-5.98686993e-01, 3.02203178e-01, 2.96483815e-01],
            ],
        ],
    ]

    activations = np.array(activations)
    labels = [[1, 2], [1, 1]]

    fn_pt = RNNTLossNumba(blank=0, reduction='sum')
    pt_costs, pt_grads = wrap_and_call(fn_pt, activations, labels, device)

    fn_np = RNNTLoss_Numpy()
    np_costs, np_grads = wrap_and_call(fn_np, activations, labels, device)

    assert np.allclose(pt_costs, sum(expected_costs)), "big_test average costs mismatch."
    assert np.allclose(pt_grads, expected_grads, rtol=1e-3), "big_test grads for average cost mismatch."
    assert np.allclose(pt_costs, np_costs), "big_test average costs mismatch."
    assert np.allclose(pt_grads, np_grads, rtol=1e-3), "big_test grads for average cost mismatch."
def test_compute_grads_kernel_clamp(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    fastemit_lambda = 0.0
    clamp = 0.1

    random = np.random.RandomState(0)
    original_shape = [1, 5, 11, 3]
    B, T, U, V = original_shape

    # Numpy kernel
    x = random.randn(*original_shape)
    labels = torch.from_numpy(np.array([[1, 1, 1, 2, 2, 2, 1, 2, 2, 1]], dtype=np.int32))  # [1, 10]
    audio_len = torch.from_numpy(np.array([T], dtype=np.int32))
    label_len = torch.from_numpy(np.array([U - 1], dtype=np.int32))
    blank_idx = 0

    x_np = torch.from_numpy(x)
    x_np.requires_grad_(True)

    """
    Here we will directly utilize the numpy variant of the loss without explicitly calling
    the numpy functions for alpha, beta and grads.

    This is because the grads returned by rnnt_numpy.transduce_batch() are
    d/dx (alpha + beta alignment)(log_softmax(x)). According to the chain rule, we would
    still need to compute the gradient of log_softmax(x) and update the alignments by hand.

    Instead, we rely on pytorch to compute the gradient of the log_softmax(x) step and
    propagate it backwards.
    """
    loss_func = rnnt_numpy.RNNTLoss(blank_idx, fastemit_lambda=fastemit_lambda, clamp=clamp)
    loss_val = loss_func(x_np, labels, audio_len, label_len)
    loss_val.sum().backward()
    true_grads = x_np.grad

    # Pytorch kernel
    device = torch.device('cuda')
    if hasattr(cuda, 'external_stream'):
        stream = cuda.external_stream(torch.cuda.current_stream(device).cuda_stream)
    else:
        stream = cuda.default_stream()

    x_c = torch.tensor(x, device=device, dtype=torch.float32)
    labels_c = torch.tensor(labels, device=device, dtype=torch.int32)

    # Allocate workspace memory
    denom = torch.zeros(B * T * U, device=device, dtype=x_c.dtype)
    alphas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype)
    betas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype)
    llForward = torch.zeros(B, device=device, dtype=x_c.dtype)
    llBackward = torch.zeros(B, device=device, dtype=x_c.dtype)

    input_lengths = torch.tensor([T], dtype=torch.int32, device=device)
    label_lengths = torch.tensor([len(labels[0])], dtype=torch.int32, device=device)

    # certify input data
    certify_inputs(x_c, labels_c, input_lengths, label_lengths)

    # flatten activation tensor (for pointer based indexing)
    x_c = x_c.view([-1])
    grads = torch.zeros_like(x_c, requires_grad=False)

    # call kernel
    # log softmax reduction
    reduce.reduce_max(x_c, denom, rows=V, cols=B * T * U, minus=False, stream=stream)
    reduce.reduce_exp(x_c, denom, rows=V, cols=B * T * U, minus=True, stream=stream)

    # alpha kernel
    gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0](
        x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx,
    )

    # beta kernel
    gpu_rnnt_kernel.compute_betas_kernel[B, U, stream, 0](
        x_c, denom, betas, llBackward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx,
    )

    # gamma kernel
    grad_blocks_per_grid = B * T * U
    grad_threads_per_block = gpu_rnnt_kernel.GPU_RNNT_THREAD_SIZE
    gpu_rnnt_kernel.compute_grad_kernel[grad_blocks_per_grid, grad_threads_per_block, stream, 0](
        grads,
        x_c,
        denom,
        alphas,
        betas,
        llForward,
        input_lengths,
        label_lengths,
        labels_c,
        B,
        T,
        U,
        V,
        blank_idx,
        fastemit_lambda,
        clamp,
    )

    # sync kernel
    stream.synchronize()

    # reshape grads
    grads = grads.view([B, T, U, V])
    diff = true_grads - grads[0].cpu().numpy()

    assert np.abs(diff).mean() <= 1e-5
    assert np.square(diff).mean() <= 1e-10
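# A minimal note on the clamping exercised above (assuming `clamp` acts element-wise on
# the final gradients, since both the Numpy reference and the CUDA kernel receive the
# same value):
#   grad = min(max(grad, -clamp), +clamp)
# so with clamp = 0.1, entries of `grads` and `true_grads` are expected to lie in [-0.1, 0.1],
# and the comparison checks that both paths clamp identically.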
def test_compute_alphas_kernel(self):
    numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    random = np.random.RandomState(0)
    original_shape = [1, 5, 11, 3]
    B, T, U, V = original_shape

    # Numpy kernel
    x = random.randn(*original_shape)
    labels = np.array([[1, 1, 1, 2, 2, 2, 1, 2, 2, 1]])  # [1, 10]
    label_len = len(labels[0]) + 1
    blank_idx = 0

    x_np = log_softmax(x, axis=-1)
    ground_alphas, ground_log_likelihood = rnnt_numpy.forward_pass(
        x_np[0, :, :label_len, :], labels[0, :label_len - 1], blank_idx
    )

    # Pytorch kernel
    device = torch.device('cuda')
    if hasattr(cuda, 'external_stream'):
        stream = cuda.external_stream(torch.cuda.current_stream(device).cuda_stream)
    else:
        stream = cuda.default_stream()

    x_c = torch.tensor(x, device=device, dtype=torch.float32)
    labels_c = torch.tensor(labels, device=device, dtype=torch.int32)

    # Allocate workspace memory
    denom = torch.zeros(B * T * U, device=device, dtype=x_c.dtype)
    alphas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype)
    llForward = torch.zeros(B, device=device, dtype=x_c.dtype)
    input_lengths = torch.tensor([T], dtype=torch.int32, device=device)
    label_lengths = torch.tensor([len(labels[0])], dtype=torch.int32, device=device)

    # certify input data
    certify_inputs(x_c, labels_c, input_lengths, label_lengths)

    # flatten activation tensor (for pointer based indexing)
    x_c = x_c.view([-1])

    # call kernel
    # log softmax reduction
    reduce.reduce_max(x_c, denom, rows=V, cols=B * T * U, minus=False, stream=stream)
    reduce.reduce_exp(x_c, denom, rows=V, cols=B * T * U, minus=True, stream=stream)

    # alpha kernel
    gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0](
        x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx,
    )

    # sync kernel
    stream.synchronize()

    # reshape alphas
    alphas = alphas.view([B, T, U])
    diff = ground_alphas - alphas[0].cpu().numpy()

    assert np.abs(diff).mean() <= 1e-5
    assert np.square(diff).mean() <= 1e-10

    ll_diff = ground_log_likelihood - llForward[0].cpu().numpy()

    assert np.abs(ll_diff).mean() <= 1e-5
    assert np.square(ll_diff).mean() <= 1e-10
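# Launch configuration used above: compute_alphas_kernel[B, U, stream, 0] follows the
# numba.cuda launch syntax [blocks_per_grid, threads_per_block, stream, shared_mem_bytes],
# i.e. one block per batch element, U threads per block (one per label position in the
# blank-extended grid), on the given stream with no dynamic shared memory.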
def test_case_small_random_accumulated(self, device):
    if device == 'cuda':
        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

    torch.manual_seed(0)
    base_layer = torch.randn(3, 5, requires_grad=True)

    mid1 = torch.randn(1, 4, 3, 3, requires_grad=True)
    labels1 = [[1, 3]]

    mid2 = torch.randn(1, 6, 5, 3, requires_grad=True)
    labels2 = [[1, 2, 3, 4]]

    def zero_grad():
        if base_layer.grad is not None:
            base_layer.grad = None
        if mid1.grad is not None:
            mid1.grad = None
        if mid2.grad is not None:
            mid2.grad = None

    fn_pt = RNNTLossNumba(blank=0, reduction='sum')
    fn_np = RNNTLoss_Numpy()

    # run 1
    acts1 = torch.matmul(mid1, base_layer)  # [1, 4, 3, 5]
    pt_cost1, _ = wrap_and_call(fn_pt, acts1, labels1, device)
    pt_grads1 = base_layer.grad.clone().cpu().numpy()

    zero_grad()

    acts1 = torch.matmul(mid1, base_layer)  # [1, 4, 3, 5]
    np_cost1, _ = wrap_and_call(fn_np, acts1, labels1, device)
    np_grads1 = base_layer.grad.clone().cpu().numpy()

    zero_grad()

    assert np.allclose(pt_grads1, np_grads1, atol=1e-6)

    # run 2
    acts2 = torch.matmul(mid2, base_layer)  # [1, 6, 5, 5]
    pt_cost2, _ = wrap_and_call(fn_pt, acts2, labels2, device)
    pt_grads2 = base_layer.grad.clone().cpu().numpy()

    zero_grad()

    acts2 = torch.matmul(mid2, base_layer)  # [1, 6, 5, 5]
    np_cost2, _ = wrap_and_call(fn_np, acts2, labels2, device)
    np_grads2 = base_layer.grad.clone().cpu().numpy()

    zero_grad()

    assert np.allclose(pt_grads2, np_grads2, atol=1e-6)

    # run 1 + 2
    acts1 = torch.matmul(mid1, base_layer)  # [1, 4, 3, 5]
    pt_cost1, _ = wrap_and_call(fn_pt, acts1, labels1, device)

    acts2 = torch.matmul(mid2, base_layer)  # [1, 6, 5, 5]
    pt_cost2, _ = wrap_and_call(fn_pt, acts2, labels2, device)

    pt_grads1_p_2 = base_layer.grad.clone().cpu().numpy()

    assert np.allclose(pt_grads1_p_2, np_grads1 + np_grads2, atol=1e-6)
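# The final "run 1 + 2" block above performs two backward passes without calling zero_grad(),
# relying on PyTorch's default behaviour of accumulating gradients into base_layer.grad.
# The check against np_grads1 + np_grads2 therefore verifies that gradients accumulated
# through the Numba loss match the sum of the individually computed NumPy reference gradients.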