def test_1d_non_contiguous(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        for dtype in [torch.int32, torch.int64]:
            num_rows = torch.randint(20, 2000, size=(1,)).item()
            stride = torch.randint(2, num_rows // 10 + 1, size=(1,)).item()
            a = torch.randint(-1000, 1000, size=(num_rows,),
                              dtype=dtype, device=device)
            a = a[::stride]
            assert a.is_contiguous() is False
            assert a.dtype == dtype

            b = torch.randint(-1, a.shape[0], size=(10000,),
                              dtype=torch.int32, device=device)
            c = k2.index_select(a, b)

            # an index of -1 selects 0; emulate this by prepending a zero
            # to `a` and shifting all indexes by one
            padded_a = torch.cat([torch.tensor([0]).to(a), a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            assert torch.allclose(c, expected)

        for dtype in [torch.float32, torch.float64]:
            num_rows = torch.randint(20, 2000, size=(1,)).item()
            stride = torch.randint(2, num_rows // 10 + 1, size=(1,)).item()
            a_contiguous = torch.rand(num_rows, dtype=dtype, device=device)
            a = a_contiguous[::stride]
            a.requires_grad_(True)
            assert a.is_contiguous() is False
            assert a.dtype == dtype

            b = torch.randint(-1, a.shape[0], size=(10000,),
                              dtype=torch.int32, device=device)
            c = k2.index_select(a, b)

            new_a = a_contiguous[::stride]
            new_a.requires_grad_(True)
            padded_a = torch.cat([torch.tensor([0]).to(a), new_a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))

            c.sum().backward()
            expected.sum().backward()

            assert torch.allclose(c, expected)
            assert torch.allclose(a.grad, new_a.grad)
def test_1d(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        for dtype in [torch.int32, torch.int64]:
            num_rows = torch.randint(1, 2000, size=(1,)).item()
            a = torch.randint(-1000, 1000, size=(num_rows,),
                              dtype=dtype, device=device)
            assert a.is_contiguous()
            assert a.dtype == dtype

            num_indexes = torch.randint(1, 200000, size=(1,)).item()
            b = torch.randint(-1, num_rows, size=(num_indexes,),
                              dtype=torch.int32, device=device)
            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            assert c.numel() == b.numel()

            padded_a = torch.cat([torch.tensor([0]).to(a), a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            assert torch.allclose(c, expected)

        for dtype in [torch.float32, torch.float64]:
            num_rows = torch.randint(1, 2000, size=(1,)).item()
            a = torch.rand(num_rows, dtype=dtype, device=device,
                           requires_grad=True)
            assert a.is_contiguous()
            assert a.dtype == dtype

            num_indexes = torch.randint(1, 200000, size=(1,)).item()
            b = torch.randint(-1, num_rows, size=(num_indexes,),
                              dtype=torch.int32, device=device)
            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            c.sum().backward()

            new_a = a.detach().requires_grad_(True)
            padded_a = torch.cat([torch.tensor([0]).to(new_a), new_a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            expected.sum().backward()

            assert torch.allclose(c, expected)
            assert torch.allclose(a.grad, new_a.grad)
def test_1d(self):
    a = torch.tensor([
        10,  # 0
        -1,  # 1
        100,  # 2
        0,  # 3
        3,  # 4
        9,  # 5
        12,  # 6
    ]).to(torch.int32)
    b = torch.tensor([0, -1, 0, 2, 1, 3, 6, -1, 5, 6, 0, 2]).to(torch.int32)
    c = k2.index_select(a, b)
    assert c.dtype == a.dtype
    assert c.numel() == b.numel()

    padded_a = torch.cat([torch.tensor([0]).to(a), a])
    expected = padded_a.index_select(0, (b + 1).to(torch.int64))
    assert torch.allclose(c, expected)

    if torch.cuda.is_available():
        device = torch.device('cuda', 0)
        a = a.to(device)
        b = b.to(device)
        c = k2.index_select(a, b)
        assert c.dtype == a.dtype
        assert c.is_cuda
        assert c.numel() == b.numel()
        assert torch.allclose(c, expected.to(c))

    # now for float32
    a = a.to(torch.float32).requires_grad_(True)
    c = k2.index_select(a, b)
    assert c.dtype == a.dtype

    c.sum().backward()
    new_a = a.detach().requires_grad_(True)
    padded_a = torch.cat([torch.tensor([0]).to(new_a), new_a])
    expected = padded_a.index_select(0, (b + 1).to(torch.int64))
    expected.sum().backward()
    assert torch.allclose(a.grad, new_a.grad)

    # now for cpu
    a.grad = None
    c = k2.index_select(a.cpu(), b.cpu())
    assert c.dtype == a.dtype
    c.sum().backward()
    assert torch.allclose(a.grad, new_a.grad.to(a.grad))
def test_2d(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        for dtype in [torch.int32, torch.int64]:
            num_rows = torch.randint(1, 2000, size=(1,)).item()
            num_cols = torch.randint(1, 2000, size=(1,)).item()
            a = torch.randint(-1000, 1000, size=(num_rows, num_cols),
                              dtype=dtype, device=device)
            b = torch.randint(-1, num_rows, size=(10000,),
                              dtype=torch.int32, device=device)
            assert a.is_contiguous()
            assert a.dtype == dtype

            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            assert c.device == a.device
            assert c.shape[1] == a.shape[1]
            assert c.shape[0] == b.shape[0]

            padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(a), a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            assert torch.allclose(c, expected)

        for dtype in [torch.float32, torch.float64]:
            src = a.to(dtype).requires_grad_(True)
            assert src.is_contiguous()
            assert b.is_contiguous()

            c = k2.index_select(src, b)
            assert c.dtype == src.dtype
            c.sum().backward()

            new_src = src.detach().requires_grad_(True)
            padded_src = torch.cat(
                [torch.zeros(1, src.shape[1]).to(new_src), new_src])
            expected = padded_src.index_select(0, (b + 1).to(torch.int64))
            expected.sum().backward()

            assert torch.allclose(c, expected)
            assert torch.allclose(src.grad, new_src.grad)
def test_1d_non_contiguous(self):
    a = torch.arange(20).to(torch.int32)[::2]
    b = torch.tensor([
        -1, -2, -1, -2, 1, -2, 2, -2, 0, -2, 1, -2, 5, -2, 3, -2, 9, -2,
        -1, -2, 8, -2, 7, -2, 7, -2, 6, -2
    ]).to(torch.int32)[::2]
    padded_a = torch.cat([torch.tensor([0]).to(a), a])
    assert a.is_contiguous() is False
    assert b.is_contiguous() is False
    c = k2.index_select(a, b)
    expected = padded_a.index_select(0, (b + 1).to(torch.int64))
    assert torch.allclose(c, expected)

    a = a.to(torch.float32).requires_grad_(True)
    c = k2.index_select(a, b)

    new_a = a.detach().clone().requires_grad_(True)
    padded_a = torch.cat([torch.tensor([0]).to(a), new_a])
    expected = padded_a.index_select(0, (b + 1).to(torch.int64))
    assert torch.allclose(c, expected.to(c))

    c.sum().backward()
    expected.sum().backward()
    assert torch.allclose(a.grad, new_a.grad)

    # now for cuda
    if torch.cuda.is_available():
        device = torch.device('cuda', 0)
        b = b.to(device)

        a.requires_grad_(False)
        a = a.to(device).requires_grad_(True)
        c = k2.index_select(a, b)

        new_a.requires_grad_(False)
        new_a = new_a.to(device).requires_grad_(True)
        padded_a = torch.cat([torch.tensor([0]).to(a), new_a])
        expected = padded_a.index_select(0, (b + 1).to(torch.int64))
        assert torch.allclose(c, expected.to(c))

        c.sum().backward()
        expected.sum().backward()
        assert torch.allclose(a.grad, new_a.grad)
def test_2d_non_contiguous(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        # keep num_rows >= 20 so that randint(2, num_rows // 10 + 1)
        # always has a non-empty range for the stride
        num_rows = torch.randint(20, 2000, size=(1,)).item()
        num_cols = torch.randint(1, 2000, size=(1,)).item()
        stride = torch.randint(2, num_rows // 10 + 1, size=(1,)).item()
        a = torch.randint(-1000, 1000, size=(num_rows, num_cols),
                          dtype=torch.int32, device=device).contiguous()
        a = a[::stride]
        num_rows = a.shape[0]
        b = torch.randint(-1, num_rows, size=(10000,),
                          dtype=torch.int32, device=device)
        assert a.is_contiguous() is False

        c = k2.index_select(a, b)
        assert c.dtype == a.dtype
        assert c.device == a.device
        assert c.shape[1] == a.shape[1]
        assert c.shape[0] == b.shape[0]

        padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(a), a])
        expected = padded_a.index_select(0, (b + 1).to(torch.int64))
        assert torch.allclose(c, expected)

        # now for float32
        a = a.to(torch.float32).requires_grad_(True)
        assert a.is_contiguous()
        assert b.is_contiguous()
        c = k2.index_select(a, b)
        assert c.dtype == a.dtype
        c.sum().backward()

        new_a = a.detach().requires_grad_(True)
        padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(new_a), new_a])
        expected = padded_a.index_select(0, (b + 1).to(torch.int64))
        expected.sum().backward()
        assert torch.allclose(a.grad, new_a.grad)
def test_1d_empty_index(self):
    for device in self.devices:
        for dtype in [torch.int32, torch.int64]:
            num_rows = torch.randint(0, 10, size=(1, )).item()
            a = torch.randint(-1000, 1000, size=(num_rows, ),
                              dtype=dtype, device=device)
            assert a.is_contiguous()
            assert a.dtype == dtype

            b = torch.empty(0, dtype=torch.int32, device=device)
            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            assert c.numel() == b.numel()

        for dtype in [torch.float32, torch.float64]:
            num_rows = torch.randint(0, 10, size=(1, )).item()
            a = torch.rand(num_rows, dtype=dtype, device=device,
                           requires_grad=True)
            assert a.is_contiguous()
            assert a.dtype == dtype

            b = torch.empty(0, dtype=torch.int32, device=device)
            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            c.sum().backward()

            new_a = a.detach().requires_grad_(True)
            padded_a = torch.cat([torch.tensor([0]).to(new_a), new_a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            expected.sum().backward()

            assert torch.allclose(c, expected)
            assert torch.allclose(a.grad, new_a.grad)
def test_sort_sublist_descending(self):
    for device in self.devices:
        for dtype in self.dtypes:
            src = k2.RaggedTensor([[3, 2], [], [1, 5, 2]], dtype).to(device)
            src_clone = src.clone()
            new2old = src.sort_(descending=True, need_new2old_indexes=True)
            sorted_src = k2.RaggedTensor([[3, 2], [], [5, 2, 1]],
                                         dtype=dtype).to(device)
            expected_new2old = torch.tensor([0, 1, 3, 4, 2],
                                            device=device,
                                            dtype=torch.int32)
            assert src == sorted_src
            assert torch.all(torch.eq(new2old, expected_new2old))

            expected_sorted = k2.index_select(src_clone.values, new2old)
            sorted = src.values
            assert torch.all(torch.eq(expected_sorted, sorted))
def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 500,
):
    """Wrap k2.intersect_device.

    This is a wrapper of k2.intersect_device. Its purpose is to split
    b_fsas into several batches and process each batch separately to
    avoid CUDA OOM errors.

    The arguments and return value of this function are the same as
    those of k2.intersect_device.

    NOTE: You can decrease batch_size in case of CUDA out-of-memory errors.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)
        fsas = k2.index_fsa(b_fsas, indexes)
        b_to_a = k2.index_select(b_to_a_map, indexes)
        path_lats = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lats)

    return k2.cat(ans)
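# A minimal usage sketch for the wrapper above, not part of the original
# code: the tiny FSA string, the three identical copies in b_fsas and
# batch_size=2 are made-up illustration values. It only assumes the
# standard k2 APIs (k2.Fsa.from_str, k2.create_fsa_vec, k2.arc_sort,
# k2.cat) and that `math`, `torch` and `k2` are imported at module level,
# as the wrapper itself requires.
def _intersect_device_example():
    # One linear acceptor in k2's text format: "src dst label score",
    # with the final state alone on the last line; -1 labels the final arc.
    s = '''
        0 1 1 0.1
        1 2 -1 0.2
        2
    '''
    a = k2.Fsa.from_str(s)

    # a_fsas holds a single FSA; it must be arc-sorted for sorted_match_a=True.
    a_fsas = k2.arc_sort(k2.create_fsa_vec([a]))
    # Three copies of the same FSA stand in for a batch of lattices.
    b_fsas = k2.create_fsa_vec([a.clone(), a.clone(), a.clone()])

    # Every FSA in b_fsas is intersected with a_fsas[0].
    b_to_a_map = torch.zeros(3, dtype=torch.int32)

    # With batch_size=2 the three FSAs are processed in two chunks
    # (sizes 2 and 1) and the partial results are joined with k2.cat().
    out = _intersect_device(a_fsas,
                            b_fsas,
                            b_to_a_map,
                            sorted_match_a=True,
                            batch_size=2)
    assert out.shape[0] == 3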
def my_func(src: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    return k2.index_select(src.to(torch.float32), index)
def test_2d_non_contiguous(self):
    for device in self.devices:
        for dtype in [torch.int32, torch.int64]:
            num_rows = torch.randint(20, 2000, size=(1, )).item()
            num_cols = torch.randint(1, 2000, size=(1, )).item()
            stride = torch.randint(2, num_rows // 10 + 1,
                                   size=(1, )).item()
            a = torch.randint(-1000, 1000, size=(num_rows, num_cols),
                              dtype=dtype, device=device).contiguous()
            a = a[::stride]
            num_rows = a.shape[0]
            b = torch.randint(-1, num_rows, size=(10000, ),
                              dtype=torch.int32, device=device)
            assert a.is_contiguous() is False
            assert a.dtype == dtype

            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            assert c.device == a.device
            assert c.shape[1] == a.shape[1]
            assert c.shape[0] == b.shape[0]

            padded_a = torch.cat([torch.zeros(1, a.shape[1]).to(a), a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            assert torch.allclose(c, expected)

        for dtype in [torch.float32, torch.float64]:
            num_rows = torch.randint(20, 2000, size=(1, )).item()
            num_cols = torch.randint(1, 2000, size=(1, )).item()
            stride = torch.randint(2, num_rows // 10 + 1,
                                   size=(1, )).item()
            a = torch.randint(-1000, 1000, size=(num_rows, num_cols),
                              dtype=dtype, device=device).contiguous()
            a = a[::stride]
            num_rows = a.shape[0]
            b = torch.randint(-1, num_rows, size=(10000, ),
                              dtype=torch.int32, device=device)
            assert a.is_contiguous() is False
            assert a.dtype == dtype
            a.requires_grad_(True)

            c = k2.index_select(a, b)
            assert c.dtype == a.dtype
            assert c.device == a.device
            assert c.shape[1] == a.shape[1]
            assert c.shape[0] == b.shape[0]
            c.sum().backward()

            new_a = a.detach().requires_grad_(True)
            padded_a = torch.cat(
                [torch.zeros(1, a.shape[1]).to(new_a), new_a])
            expected = padded_a.index_select(0, (b + 1).to(torch.int64))
            expected.sum().backward()

            assert torch.allclose(c, expected)
            assert torch.allclose(a.grad, new_a.grad)
def forward(
    self,
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
) -> torch.Tensor:
    assert self.graph_compiler is not None
    boosted = self.boost_coeff != 0.0

    if self.blank != 0:
        # rearrange log_probs to put blank at the first place
        # and shift targets to emulate blank = 0
        log_probs, targets = make_blank_first(self.blank, log_probs, targets)
    supervisions, order = create_supervision(input_lengths)
    order = order.long()
    targets = targets[order]
    target_lengths = target_lengths[order]

    if log_probs.device != self.graph_compiler.device:
        self.graph_compiler.to(log_probs.device)
    num_graphs, den_graph = self.graph_compiler.compile(
        targets + 1 if self.pad_fsavec else targets, target_lengths
    )

    dense_fsa_vec = (
        prep_padded_densefsavec(log_probs, supervisions)
        if self.pad_fsavec
        else k2.DenseFsaVec(log_probs, supervisions)
    )

    num_tot_scores, den_tot_scores, num_lats, den_lats = self.intersect_calc_scores(
        dense_fsa_vec, num_graphs, den_graph, boosted
    )

    tot_scores = num_tot_scores - den_tot_scores
    mmi_tot_scores, mmi_valid_mask = get_tot_objf_and_finite_mask(tot_scores, self.reduction)

    if boosted:
        assert num_lats is not None and den_lats is not None

        size = (
            dense_fsa_vec.dim0(),
            dense_fsa_vec.scores.shape[0],
            dense_fsa_vec.scores.shape[1] - 1,
        )
        row_ids = dense_fsa_vec.dense_fsa_vec.shape().row_ids(1)
        num_sparse = create_sparse_wrapped(
            indices=[
                k2.index_select(row_ids, num_lats.seqframe_idx),
                num_lats.seqframe_idx,
                num_lats.phones,
            ],
            values=num_lats.get_arc_post(False, True).exp(),
            size=size,
            min_col_index=0,
        )
        den_sparse = create_sparse_wrapped(
            indices=[
                k2.index_select(row_ids, den_lats.seqframe_idx),
                den_lats.seqframe_idx,
                den_lats.phones,
            ],
            values=den_lats.get_arc_post(False, True).exp(),
            size=size,
            min_col_index=0,
        )

        # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
        # we cannot use (num_sparse - den_sparse) here
        # TODO (alaptev): propose sparse_abs to k2
        acc_loss = torch.sparse.sum(
            sparse_abs((num_sparse + (-den_sparse)).coalesce()), (1, 2)
        ).to_dense()

        acc_tot_scores, acc_valid_mask = get_tot_objf_and_finite_mask(acc_loss, self.reduction)
        valid_mask = mmi_valid_mask & acc_valid_mask
        total_loss = self.boost_coeff * acc_tot_scores[valid_mask] - mmi_tot_scores[valid_mask]
    else:
        valid_mask = mmi_valid_mask
        total_loss = -mmi_tot_scores[mmi_valid_mask]
    return total_loss, valid_mask