def test_log_single_fsa(self):
    s = '''
        0 1 1 0.1
        0 2 2 0.2
        1 2 3 0.3
        1 3 4 0.4
        2 3 5 0.5
        3 4 -1 0
        4
    '''
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        fsa = k2.Fsa.from_str(s).to(device)
        fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa])
        log_like = k2.get_tot_scores(fsa_vec,
                                     log_semiring=True,
                                     use_float_scores=True)
        assert log_like.dtype == torch.float32
        # The expected_log_like is computed using gtn.
        # See https://bit.ly/3oUiRx9
        expected_log_like = torch.tensor([1.8119014501571655]).to(device)
        assert torch.allclose(log_like, expected_log_like)

        # The expected_grad is computed using gtn.
        # See https://bit.ly/3oUiRx9
        expected_grad = torch.tensor([
            0.6710670590400696, 0.32893291115760803, 0.4017595648765564,
            0.2693074941635132, 0.7306925058364868, 1.0
        ]).to(device)

        scale = -1.75
        (scale * log_like).sum().backward()
        assert torch.allclose(fsa.scores.grad, scale * expected_grad)

        # now for double
        fsa.scores.grad = None
        log_like = k2.get_tot_scores(fsa_vec,
                                     log_semiring=True,
                                     use_float_scores=False)
        assert log_like.dtype == torch.float64
        expected_log_like = expected_log_like.to(torch.float64)
        assert torch.allclose(log_like, expected_log_like)

        scale = 10.25
        (scale * log_like).sum().backward()
        assert torch.allclose(fsa.scores.grad, scale * expected_grad)
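For reference, the 1.8119 value can be reproduced without gtn: the FSA above has exactly three start-to-final paths, and in the log semiring the total score is the log-sum-exp of their path scores. A standalone sketch (not part of the test):

import torch

# Path scores: arcs 0,2,4,5 -> 0.9; arcs 0,3,5 -> 0.5; arcs 1,4,5 -> 0.7.
path_scores = torch.tensor([0.1 + 0.3 + 0.5, 0.1 + 0.4, 0.2 + 0.5])
print(torch.logsumexp(path_scores, dim=0))  # tensor(1.8119), matching gtn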
def test_tropical_single_fsa(self):
    # best path arc indexes are: 1, 3, 5, 10
    s = '''
        0 4 1 1
        0 1 1 1
        1 2 1 2
        1 3 1 3
        2 7 1 4
        3 7 1 5
        4 6 1 2
        4 8 1 3
        5 9 -1 4
        6 9 -1 3
        7 9 -1 5
        8 9 -1 6
        9
    '''
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))
    for device in devices:
        fsa = k2.Fsa.from_str(s).to(device)
        fsa = k2.create_fsa_vec([fsa])
        fsa.requires_grad_(True)
        log_like = k2.get_tot_scores(fsa,
                                     log_semiring=False,
                                     use_float_scores=True)
        assert log_like == 14
        assert log_like.dtype == torch.float32

        scale = -10
        (scale * log_like).sum().backward()
        expected = torch.zeros(len(fsa.scores)).to(device)
        expected[torch.tensor([1, 3, 5, 10])] = 1
        assert torch.allclose(fsa.scores.grad, scale * expected)

        # now for double
        fsa.scores.grad = None
        log_like = k2.get_tot_scores(fsa,
                                     log_semiring=False,
                                     use_float_scores=False)
        assert log_like == 14
        assert log_like.dtype == torch.float64

        scale = -1.25
        (scale * log_like).sum().backward()
        assert torch.allclose(fsa.scores.grad, scale * expected)
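The tropical-semiring value can likewise be checked by hand: the FSA has four start-to-final paths (state 5 is unreachable), the total score is the maximum over them, and the gradient is the 0/1 indicator of the arcs on the argmax path (arcs 1, 3, 5, 10). A standalone sketch:

import torch

# Path totals: 0->4->6->9, 0->4->8->9, 0->1->2->7->9, 0->1->3->7->9.
path_scores = torch.tensor(
    [1 + 2 + 3.0, 1 + 3 + 6.0, 1 + 2 + 4 + 5.0, 1 + 3 + 5 + 5.0])
assert path_scores.max() == 14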
def test_two_fsas(self):
    s1 = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    s2 = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    fsa1 = k2.Fsa.from_str(s1)
    fsa2 = k2.Fsa.from_str(s2)

    fsa1.requires_grad_(True)
    fsa2.requires_grad_(True)

    fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
    log_prob = torch.tensor(
        [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
         [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
        dtype=torch.float32,
        requires_grad=True)

    supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                        dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
    out_fsa = k2.intersect_dense(fsa_vec, dense_fsa_vec, output_beam=100000)
    assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

    scores = k2.get_tot_scores(out_fsa,
                               log_semiring=False,
                               use_float_scores=True)
    scores.sum().backward()

    # `expected` results are computed using gtn.
    # See https://bit.ly/3oYObeb
    # expected_scores_out_fsa = torch.tensor(
    #     [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])
    expected_grad_fsa1 = torch.tensor([1.0, 1.0, 1.0, 1.0])
    expected_grad_fsa2 = torch.tensor([1.0, 1.0, 1.0])
    print('fsa2 is ', fsa2)

    # TODO(dan): fix this..
    # expected_grad_log_prob = torch.tensor([
    #     0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
    #     0.0, 0.0, 0.0, 1.0
    # ]).reshape_as(log_prob)

    # assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
    assert torch.allclose(expected_grad_fsa1, fsa1.scores.grad)
    assert torch.allclose(expected_grad_fsa2, fsa2.scores.grad)
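The all-ones reference gradients follow from the supervision lengths: with 3 frames the only labelling s1 can emit is "1 1 2", so the single surviving path uses each of its four arcs exactly once, and with 2 frames s2's unique path "1 2" uses each of its three arcs once. A standalone sanity check of the first path's total score (arithmetic only; this value is not asserted in the test itself):

# frame 0: arc 0->1 (label 1), frame 1: self-loop 1->1 (label 1),
# frame 2: arc 1->2 (label 2), then the final arc.
tot_score_fsa1 = (1.0 + 0.2) + (50.0 + 0.05) + (2.0 + 0.0) + 3.0
print(tot_score_fsa1)  # 56.25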
def test_simple(self):
    s = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    fsa = k2.Fsa.from_str(s)
    fsa.requires_grad_(True)
    fsa_vec = k2.create_fsa_vec([fsa])
    log_prob = torch.tensor([[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]],
                            dtype=torch.float32,
                            requires_grad=True)

    supervision_segments = torch.tensor([[0, 0, 2]], dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
    out_fsa = k2.intersect_dense(fsa_vec, dense_fsa_vec, output_beam=100000)
    scores = k2.get_tot_scores(out_fsa,
                               log_semiring=False,
                               use_float_scores=True)
    scores.sum().backward()

    # `expected` results are computed using gtn.
    # See https://bit.ly/3oYObeb
    expected_scores_out_fsa = torch.tensor([1.2, 2.06, 3.0])
    expected_grad_fsa = torch.tensor([1.0, 0.0, 1.0, 1.0])
    expected_grad_log_prob = torch.tensor(
        [0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).reshape_as(log_prob)

    assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
    assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
    assert torch.allclose(expected_grad_log_prob, log_prob.grad)
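The reference out_fsa scores can be derived directly: each arc of the intersection carries the graph arc's score plus the nnet log-prob of its label at the matching frame, while the final arc keeps its score. A standalone sketch:

import torch

log_prob = torch.tensor([[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]])
expected = torch.tensor([1.0 + log_prob[0, 0, 1].item(),   # arc 0->1, frame 0
                         2.0 + log_prob[0, 1, 2].item(),   # arc 1->2, frame 1
                         3.0])                             # final arc
print(expected)  # tensor([1.2000, 2.0600, 3.0000])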
def forward(self, log_probs: torch.Tensor, targets: torch.Tensor,
            input_lengths: torch.Tensor,
            target_lengths: torch.Tensor) -> torch.Tensor:
    # At entry, log_probs is [T, N, C]; make it [N, T, C],
    # i.e. batch_size x seq_length x alphabet_size.
    log_probs = log_probs.permute(1, 0, 2).cpu()

    # One row per sequence: (sequence_idx, start_frame, num_frames),
    # reordered so that num_frames is in decreasing order.
    supervision_segments = torch.stack(
        (torch.arange(input_lengths.shape[0]),
         torch.zeros(input_lengths.shape[0]), input_lengths),
        1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
    decoding_graph = self.graph_compiler.compile(targets.cpu(),
                                                 target_lengths)
    # Reorder the decoding graphs to match the sorted supervisions.
    decoding_graph = k2.index(decoding_graph,
                              indices.to(torch.int32)).to(log_probs.device)

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = k2.get_tot_scores(target_graph,
                                   log_semiring=True,
                                   use_double_scores=True)
    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])
    return -tot_score
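A standalone sketch of the supervision_segments layout built above (lengths here are made-up values): one int32 row per sequence, (sequence_idx, start_frame, num_frames), reordered so num_frames decreases:

import torch

input_lengths = torch.tensor([95, 100, 80])
segments = torch.stack(
    (torch.arange(3), torch.zeros(3), input_lengths), 1).to(torch.int32)
order = torch.argsort(segments[:, 2], descending=True)
print(segments[order])
# tensor([[  1,   0, 100],
#         [  0,   0,  95],
#         [  2,   0,  80]], dtype=torch.int32)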
def test_autograd(self):
    s0 = '''
        0 1 1 0.1
        0 2 2 0.2
        1 3 -1 0.3
        1 2 2 0.4
        2 3 -1 0.5
        3
    '''
    s1 = '''
        0 2 -1 0.6
        0 1 1 0.7
        1 2 -1 0.8
        2
    '''
    s2 = '''
        0 1 1 1.1
        1 2 -1 1.2
        2
    '''
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        fsa0 = k2.Fsa.from_str(s0).to(device).requires_grad_(True)
        fsa1 = k2.Fsa.from_str(s1).to(device).requires_grad_(True)
        fsa2 = k2.Fsa.from_str(s2).to(device).requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2])
        fsa = k2.union(fsa_vec)
        fsa_vec = k2.create_fsa_vec([fsa])

        log_like = k2.get_tot_scores(fsa_vec,
                                     log_semiring=True,
                                     use_double_scores=False)
        # expected log_like and gradients are computed using gtn.
        # See https://bit.ly/35uVaUv
        log_like.backward()
        expected_log_like = torch.tensor([3.1136]).to(log_like)
        assert torch.allclose(log_like, expected_log_like)

        expected_grad_fsa0 = torch.tensor([
            0.18710044026374817, 0.08949274569749832, 0.06629786640405655,
            0.12080258131027222, 0.21029533445835114
        ]).to(device)

        expected_grad_fsa1 = torch.tensor([
            0.08097638934850693, 0.19916976988315582, 0.19916976988315582
        ]).to(device)

        expected_grad_fsa2 = torch.tensor(
            [0.4432605803012848, 0.4432605803012848]).to(device)

        assert torch.allclose(fsa0.grad, expected_grad_fsa0)
        assert torch.allclose(fsa1.grad, expected_grad_fsa1)
        assert torch.allclose(fsa2.grad, expected_grad_fsa2)
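The 3.1136 reference can also be verified by brute force: a union's log-semiring total is the log-sum-exp over all paths of all component FSAs. fsa0 has paths with scores 0.4, 1.0 and 0.7; fsa1 has 0.6 and 1.5; fsa2 has 2.3. A standalone sketch:

import torch

path_scores = torch.tensor([0.4, 1.0, 0.7, 0.6, 1.5, 2.3])
print(torch.logsumexp(path_scores, dim=0))  # tensor(3.1136), matching gtn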
def test_tropical_single_fsa(self):
    # best path arc indexes are: 1, 3, 5, 10
    s = '''
        0 4 1 1
        0 1 1 1
        1 2 1 2
        1 3 1 3
        2 7 1 4
        3 7 1 5
        4 6 1 2
        4 8 1 3
        5 9 -1 4
        6 9 -1 3
        7 9 -1 5
        8 9 -1 6
        9
    '''
    fsa = k2.Fsa.from_str(s)
    fsa = k2.create_fsa_vec([fsa])
    fsa.requires_grad_(True)
    log_like = k2.get_tot_scores(fsa,
                                 log_semiring=False,
                                 use_float_scores=True)
    assert log_like == 14
    assert log_like.dtype == torch.float32

    log_like.sum().backward()
    expected = torch.zeros(len(fsa.scores))
    expected[torch.tensor([1, 3, 5, 10])] = 1
    assert torch.allclose(fsa.scores.grad, expected)

    # now for double
    fsa.scores.grad = None
    log_like = k2.get_tot_scores(fsa,
                                 log_semiring=False,
                                 use_float_scores=False)
    assert log_like == 14
    assert log_like.dtype == torch.float64

    log_like.sum().backward()
    assert torch.allclose(fsa.scores.grad, expected)
def test_two_fsas_long_pruned(self):
    # as test_two_fsas_long in intersect_dense_test.py,
    # but with pruned intersection
    s1 = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    s2 = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        fsa1 = k2.Fsa.from_str(s1)
        fsa2 = k2.Fsa.from_str(s2)

        fsa1.requires_grad_(True)
        fsa2.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
        log_prob = torch.rand((2, 100, 3),
                              dtype=torch.float32,
                              device=device,
                              requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 95], [1, 20, 50]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        fsa_vec = fsa_vec.to(device)
        out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                            dense_fsa_vec,
                                            search_beam=100,
                                            output_beam=100,
                                            min_active_states=1,
                                            max_active_states=10)
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = k2.get_tot_scores(out_fsa,
                                   log_semiring=False,
                                   use_double_scores=False)
        scores.sum().backward()
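Because the log-probs are random there are no closed-form reference values here; the test is a smoke test for pruned intersection and its backward pass. A sketch of gradient-flow checks one might append inside the loop (reusing the loop-local names above; not part of the original test):

        # after scores.sum().backward(): gradients should have reached
        # both the FSA scores and the dense log-probs.
        assert fsa1.scores.grad is not None
        assert fsa2.scores.grad is not None
        assert log_prob.grad is not None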
def get_objf(batch, model, device, L, symbols, training, optimizer=None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'], supervisions['start_frame'],
         supervisions['num_frames']), 1).to(torch.int32)
    texts = supervisions['text']
    assert feature.ndim == 3
    # print(feature.shape)
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    feature = feature.to(device)
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    # TODO(haowen): create decoding graph at the beginning of training
    decoding_graph = create_decoding_graph(texts, L, symbols)
    decoding_graph.to_(device)
    decoding_graph.scores.requires_grad_(False)
    # print(nnet_output.shape)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    # dense_fsa_vec.scores.requires_grad_(True)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device
    # print(nnet_output.get_device())
    print(decoding_graph.arcs)
    print(dense_fsa_vec.dense_fsa_vec)
    target_graph = k2.intersect_dense_pruned(decoding_graph, dense_fsa_vec,
                                             10, 10000, 0)
    tot_scores = -k2.get_tot_scores(target_graph, True, False).sum()
    if training:
        optimizer.zero_grad()
        tot_scores.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    objf = tot_scores.detach().cpu()
    total_objf = objf.item()
    total_frames = nnet_output.shape[0]

    return total_objf, total_frames
def test_two_fsas_long(self):
    # as test_two_fsas, but generate long DenseFsaVec for easier profiling.
    s1 = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    s2 = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        fsa1 = k2.Fsa.from_str(s1)
        fsa2 = k2.Fsa.from_str(s2)

        fsa1.requires_grad_(True)
        fsa2.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
        log_prob = torch.rand((2, 500, 3),
                              dtype=torch.float32,
                              device=device,
                              requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 490], [1, 0, 300]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        fsa_vec = fsa_vec.to(device)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000)
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = k2.get_tot_scores(out_fsa,
                                   log_semiring=False,
                                   use_float_scores=True)
        scores.sum().backward()
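Since this test exists to make profiling easier, here is a sketch of one way to time the intersection (a generic PyTorch timing pattern reusing the loop-local names above, not part of the test): synchronize around the measured region so that any asynchronous CUDA work is included in the wall-clock time.

        import time

        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        out_fsa = k2.intersect_dense(fsa_vec, dense_fsa_vec,
                                     output_beam=100000)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        print(f'intersect_dense took {time.time() - start:.3f}s on {device}')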
def test_two_dense(self):
    s = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    fsa = k2.Fsa.from_str(s)
    fsa.requires_grad_(True)
    fsa_vec = k2.create_fsa_vec([fsa])
    log_prob = torch.tensor(
        [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
         [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
        dtype=torch.float32,
        requires_grad=True)

    supervision_segments = torch.tensor([[0, 0, 2], [1, 0, 3]],
                                        dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
    out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                        dense_fsa_vec,
                                        search_beam=100000,
                                        output_beam=100000,
                                        min_active_states=0,
                                        max_active_states=10000)
    assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

    scores = k2.get_tot_scores(out_fsa,
                               log_semiring=False,
                               use_double_scores=False)
    scores.sum().backward()

    # `expected` results are computed using gtn.
    # See https://bit.ly/3oYObeb
    expected_scores_out_fsa = torch.tensor(
        [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])
    expected_grad_fsa = torch.tensor([2.0, 1.0, 2.0, 2.0])
    expected_grad_log_prob = torch.tensor([
        0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
        0.0, 0.0, 0.0, 1.0
    ]).reshape_as(log_prob)

    assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
    assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
    assert torch.allclose(expected_grad_log_prob, log_prob.grad)
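The expected_grad_fsa values can be read off by hand: the same graph is intersected with both supervisions, so its gradient accumulates over the two output FSAs. A standalone sketch:

import torch

# 2-frame supervision forces path "1 2"   -> arcs 0, 2 and the final arc;
# 3-frame supervision forces path "1 1 2" -> arcs 0, 1, 2 and the final arc.
grad_seq0 = torch.tensor([1.0, 0.0, 1.0, 1.0])
grad_seq1 = torch.tensor([1.0, 1.0, 1.0, 1.0])
print(grad_seq0 + grad_seq1)  # tensor([2., 1., 2., 2.])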
def test_compose(self):
    s = '''
        0 1 11 1 1.0
        0 2 12 2 2.5
        1 3 -1 -1 0
        2 3 -1 -1 2.5
        3
    '''
    a_fsa = k2.Fsa.from_str(s).requires_grad_(True)

    s = '''
        0 1 1 1 1.0
        0 2 2 3 3.0
        1 2 3 2 2.5
        2 3 -1 -1 2.0
        3
    '''
    b_fsa = k2.Fsa.from_str(s).requires_grad_(True)

    ans = k2.compose(a_fsa, b_fsa, inner_labels='inner')
    ans = k2.connect(ans)

    # Convert a single FSA to an FsaVec.
    # This retains `requires_grad` of `ans`.
    ans.__dict__['arcs'] = _k2.create_fsa_vec([ans.arcs])

    scores = k2.get_tot_scores(ans,
                               log_semiring=True,
                               use_double_scores=False)
    # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
    # are computed using GTN.
    # See https://bit.ly/3heLAJq
    assert scores.item() == 10

    scores.backward()
    assert torch.allclose(a_fsa.grad, torch.tensor([0., 1., 0., 1.]))
    assert torch.allclose(b_fsa.grad, torch.tensor([0., 1., 0., 1.]))
    print(ans)
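Sanity arithmetic for the reference values: the only composable path pairs a_fsa's arcs 1 and 3 (output label 2, then the final arc) with b_fsa's arcs 1 and 3 (input label 2, then the final arc); with a single surviving path the log-semiring total equals the path score, which also explains the 0/1 gradients above.

# a_fsa contributes 2.5 + 2.5, b_fsa contributes 3.0 + 2.0:
assert (2.5 + 2.5) + (3.0 + 2.0) == 10.0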
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device
    # TODO(haowen): with a small `beam`, we may get an empty `target_graph`,
    # and then `tot_scores` will be `inf`. We definitely need to handle
    # this later.
    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = k2.get_tot_scores(target_graph,
                                   log_semiring=True,
                                   use_double_scores=True)
    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
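A small illustration of the subsampled-frame arithmetic above, assuming subsampling_factor = 4 (the factor itself is model-specific): frame indices and counts at the input rate are mapped to the subsampled rate by floor division.

import torch

start_frame = torch.tensor([8, 10])
num_frames = torch.tensor([1000, 997])
print(torch.floor_divide(start_frame, 4))  # tensor([2, 2])
print(torch.floor_divide(num_frames, 4))   # tensor([250, 249])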
def test_tropical_multiple_fsas(self):
    # best path:
    #   states: 0 -> 1 -> 3 -> 7 -> 9
    #   arcs:     1 -> 3 -> 5 -> 10
    #   scores:   1 + 3 + 5 + 5 = 14
    s1 = '''
        0 4 1 1
        0 1 1 1
        1 2 1 2
        1 3 1 3
        2 7 1 4
        3 7 1 5
        4 6 1 2
        4 8 1 3
        5 9 -1 4
        6 9 -1 3
        7 9 -1 5
        8 9 -1 6
        9
    '''

    # best path:
    #   states: 0 -> 2 -> 3 -> 4 -> 5
    #   arcs:     1 -> 4 -> 5 -> 7
    #   scores:   6 + 4 + 3 + 0 = 13
    s2 = '''
        0 1 1 1
        0 2 2 6
        1 2 3 3
        1 3 4 2
        2 3 5 4
        3 4 6 3
        3 5 -1 2
        4 5 -1 0
        5
    '''

    # best path:
    #   states: 0 -> 2 -> 3
    #   arcs:     1 -> 3
    #   scores:   100 + 5.5 = 105.5
    s3 = '''
        0 1 1 10
        0 2 2 100
        1 3 -1 3.5
        2 3 -1 5.5
        3
    '''
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        fsa1 = k2.Fsa.from_str(s1).to(device)
        fsa2 = k2.Fsa.from_str(s2).to(device)
        fsa3 = k2.Fsa.from_str(s3).to(device)

        fsa1.requires_grad_(True)
        fsa2.requires_grad_(True)
        fsa3.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2, fsa3])
        log_like = k2.get_tot_scores(fsa_vec,
                                     log_semiring=False,
                                     use_float_scores=True)
        assert log_like.dtype == torch.float32
        expected_log_like = torch.tensor([14, 13, 105.5]).to(device)
        assert torch.allclose(log_like, expected_log_like)

        scale = torch.tensor([-10, -20, -30.]).to(log_like)
        (log_like * scale).sum().backward()

        fsa1_best_arc_indexes = torch.tensor([1, 3, 5, 10]).to(device)
        assert torch.allclose(
            fsa1.scores.grad[fsa1_best_arc_indexes],
            scale[0] * torch.ones(4, dtype=torch.float32).to(device))
        assert fsa1.scores.grad.sum() == 4 * scale[0]

        fsa2_best_arc_indexes = torch.tensor([1, 4, 5, 7]).to(device)
        assert torch.allclose(
            fsa2.scores.grad[fsa2_best_arc_indexes],
            scale[1] * torch.ones(4, dtype=torch.float32).to(device))
        assert fsa2.scores.grad.sum() == 4 * scale[1]

        fsa3_best_arc_indexes = torch.tensor([1, 3]).to(device)
        assert torch.allclose(
            fsa3.scores.grad[fsa3_best_arc_indexes],
            scale[2] * torch.ones(2, dtype=torch.float32).to(device))
        assert fsa3.scores.grad.sum() == 2 * scale[2]

        # now for double
        fsa1.scores.grad = None
        fsa2.scores.grad = None
        fsa3.scores.grad = None

        log_like = k2.get_tot_scores(fsa_vec,
                                     log_semiring=False,
                                     use_float_scores=False)
        assert log_like.dtype == torch.float64
        expected_log_like = expected_log_like.to(torch.float64)
        assert torch.allclose(log_like, expected_log_like)

        scale = torch.tensor([-1.25, -2.5, 3.5]).to(log_like)
        (scale * log_like).sum().backward()

        assert torch.allclose(
            fsa1.scores.grad[fsa1_best_arc_indexes],
            scale[0] * torch.ones(4, dtype=torch.float32).to(device))
        assert fsa1.scores.grad.sum() == 4 * scale[0]

        assert torch.allclose(
            fsa2.scores.grad[fsa2_best_arc_indexes],
            scale[1] * torch.ones(4, dtype=torch.float32).to(device))
        assert fsa2.scores.grad.sum() == 4 * scale[1]

        assert torch.allclose(
            fsa3.scores.grad[fsa3_best_arc_indexes],
            scale[2] * torch.ones(2, dtype=torch.float32).to(device))
        assert fsa3.scores.grad.sum() == 2 * scale[2]
def test_log_multiple_fsas(self):
    s1 = '''
        0 1 1 0.1
        0 2 2 0.2
        1 2 3 0.3
        1 3 4 0.4
        2 3 5 0.5
        3 4 -1 0
        4
    '''
    s2 = '''
        0 3 3 0.1
        0 1 1 0.2
        0 2 2 0.3
        1 2 2 0.4
        1 3 3 0.5
        2 3 3 0.6
        2 4 4 0.7
        3 4 4 0.8
        3 5 -1 0.9
        4 5 -1 1.0
        5
    '''
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        fsa1 = k2.Fsa.from_str(s1).to(device)
        fsa2 = k2.Fsa.from_str(s2).to(device)

        fsa1.requires_grad_(True)
        fsa2.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
        log_like = k2.get_tot_scores(fsa_vec,
                                     log_semiring=True,
                                     use_float_scores=True)
        assert log_like.dtype == torch.float32
        # The expected_log_likes are computed using gtn.
        # See https://bit.ly/3oUiRx9
        expected_log_like = torch.tensor(
            [1.8119014501571655, 4.533502578735352]).to(device)
        assert torch.allclose(log_like, expected_log_like)

        scale = torch.tensor([1.25, -5.25]).to(log_like)
        (scale * log_like).sum().backward()

        # The expected_grads are computed using gtn.
        # See https://bit.ly/3oUiRx9
        expected_grad_fsa1 = torch.tensor([
            0.6710670590400696, 0.32893291115760803, 0.4017595648765564,
            0.2693074941635132, 0.7306925058364868, 1.0
        ]).to(device)

        expected_grad_fsa2 = torch.tensor([
            0.10102888941764832, 0.5947467088699341, 0.3042244613170624,
            0.410660058259964, 0.1840866357088089, 0.5283515453338623,
            0.18653297424316406, 0.5783339142799377, 0.2351330667734146,
            0.764866828918457
        ]).to(device)

        assert torch.allclose(fsa1.scores.grad,
                              scale[0] * expected_grad_fsa1)
        assert torch.allclose(fsa2.scores.grad,
                              scale[1] * expected_grad_fsa2)

        # now for double
        fsa1.scores.grad = None
        fsa2.scores.grad = None

        log_like = k2.get_tot_scores(fsa_vec,
                                     log_semiring=True,
                                     use_float_scores=False)
        assert log_like.dtype == torch.float64
        expected_log_like = expected_log_like.to(torch.float64)
        assert torch.allclose(log_like, expected_log_like)

        scale = torch.tensor([-10.25, 8.25]).to(log_like)
        (scale * log_like).sum().backward()

        assert torch.allclose(fsa1.scores.grad,
                              scale[0] * expected_grad_fsa1)
        assert torch.allclose(fsa2.scores.grad,
                              scale[1] * expected_grad_fsa2)
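The second reference value (4.5335) can also be reproduced by brute force: s2 has ten start-to-final paths, and its log-semiring total is their log-sum-exp (s1 is the same FSA as in test_log_single_fsa above). A standalone sketch:

import torch

# The ten path scores of s2, enumerated by hand:
path_scores = torch.tensor([
    0.1 + 0.9, 0.1 + 0.8 + 1.0, 0.2 + 0.5 + 0.9,
    0.2 + 0.5 + 0.8 + 1.0, 0.2 + 0.4 + 0.6 + 0.9,
    0.2 + 0.4 + 0.6 + 0.8 + 1.0, 0.2 + 0.4 + 0.7 + 1.0,
    0.3 + 0.6 + 0.9, 0.3 + 0.6 + 0.8 + 1.0, 0.3 + 0.7 + 1.0
])
print(torch.logsumexp(path_scores, dim=0))  # tensor(4.5335), matching gtn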