def test_no_repeated(self):
    # The standard CTC topo and the modified CTC topo should be
    # equivalent if there are no repeated neighboring symbols
    # in the transcript.
    max_token = 3
    standard = k2.ctc_topo(max_token, modified=False)
    modified = k2.ctc_topo(max_token, modified=True)
    transcript = k2.linear_fsa([1, 2, 3])
    standard_graph = k2.compose(standard, transcript)
    modified_graph = k2.compose(modified, transcript)

    input1 = k2.linear_fsa([1, 1, 1, 0, 0, 2, 2, 3, 3])
    input2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 3, 3])
    inputs = [input1, input2]
    for i in inputs:
        lattice1 = k2.intersect(standard_graph,
                                i,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                i,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)

        aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels1 = aux_labels1[:-1]  # remove -1
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels1, aux_labels2))
        assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 3])))

def test_with_repeated(self):
    max_token = 2
    standard = k2.ctc_topo(max_token, modified=False)
    modified = k2.ctc_topo(max_token, modified=True)
    transcript = k2.linear_fsa([1, 2, 2])
    standard_graph = k2.compose(standard, transcript)
    modified_graph = k2.compose(modified, transcript)

    # There is a blank separating the two 2s in the input,
    # so the standard and modified CTC topos should be equivalent.
    input = k2.linear_fsa([1, 1, 2, 2, 0, 2, 2, 0, 0])
    lattice1 = k2.intersect(standard_graph,
                            input,
                            treat_epsilons_specially=False)
    lattice2 = k2.intersect(modified_graph,
                            input,
                            treat_epsilons_specially=False)
    lattice1 = k2.connect(lattice1)
    lattice2 = k2.connect(lattice2)

    aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
    aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
    aux_labels1 = aux_labels1[:-1]  # remove -1
    aux_labels2 = aux_labels2[:-1]
    assert torch.all(torch.eq(aux_labels1, aux_labels2))
    assert torch.all(torch.eq(aux_labels1, torch.tensor([1, 2, 2])))

    # There is no blank separating the two 2s in the input.
    # The standard CTC topo requires a blank between repeated
    # symbols, so lattice1 below is empty.
    input = k2.linear_fsa([1, 1, 2, 2, 0, 0])
    lattice1 = k2.intersect(standard_graph,
                            input,
                            treat_epsilons_specially=False)
    lattice2 = k2.intersect(modified_graph,
                            input,
                            treat_epsilons_specially=False)
    lattice1 = k2.connect(lattice1)
    lattice2 = k2.connect(lattice2)
    assert lattice1.num_arcs == 0

    # Since there are two 2s in the input and also two 2s in the
    # transcript, the final output contains only one path.
    # If there were more than two 2s in the input, the output
    # would contain more than one path.
    aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
    aux_labels2 = aux_labels2[:-1]  # remove -1
    assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 2])))

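# A minimal sketch (not part of the original tests) of how one might read the
# recovered transcript off a lattice via k2.shortest_path instead of filtering
# lattice.aux_labels directly. The helper name `_best_path_aux_labels` is
# hypothetical; it assumes the graph was built by composing a CTC topology with
# a linear transcript FSA (so `aux_labels` is a plain torch tensor) and that
# the resulting lattice is non-empty.
def _best_path_aux_labels(graph, inputs):
    lattice = k2.intersect(graph, inputs, treat_epsilons_specially=False)
    lattice = k2.connect(lattice)
    best = k2.shortest_path(lattice, use_double_scores=True)
    aux = best.aux_labels
    # Keep only real symbols: drop epsilons (0) and the trailing -1.
    return aux[aux > 0]


# For example, with `modified_graph` and the second `input` from
# test_with_repeated above, this is expected to return tensor([1, 2, 2]).
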
def test_case1(self):
    for device in self.devices:
        # suppose we have five symbols: <blk>, a, b, c, d
        torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                         0.2]).to(device)
        k2_activation = torch_activation.detach().clone()

        # (T, N, C)
        torch_activation = torch_activation.reshape(
            1, 1, -1).requires_grad_(True)

        # (N, T, C)
        k2_activation = k2_activation.reshape(1, 1,
                                              -1).requires_grad_(True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its label is `a`
        targets = torch.tensor([1]).to(device)
        input_lengths = torch.tensor([1]).to(device)
        target_lengths = torch.tensor([1]).to(device)
        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='mean')

        assert torch.allclose(torch_loss,
                              torch.tensor([1.6094379425049]).to(device))

        # (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                       dim=-1)

        supervision_segments = torch.tensor([[0, 0, 1]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo = k2.ctc_topo(4)
        linear_fsa = k2.linear_fsa([1])
        decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

        k2_loss = k2.ctc_loss(decoding_graph,
                              dense_fsa_vec,
                              reduction='mean',
                              target_lengths=target_lengths)

        assert torch.allclose(torch_loss, k2_loss)

        torch_loss.backward()
        k2_loss.backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)

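# Worked check (a sketch, not part of the original tests) for the constant
# 1.6094379425049 asserted in test_case1: with uniform activations over the
# five symbols, a single frame and the single-symbol target `a`, the only
# valid CTC alignment emits `a` at that frame, so the loss is
# -log(1/5) = log(5).
def _check_case1_expected_loss():
    import math
    expected = 1.6094379425049  # float32 value asserted in test_case1
    assert abs(math.log(5) - expected) < 1e-6
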
def test_case3(self):
    for device in self.devices:
        # (T, N, C)
        torch_activation = torch.tensor([[
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]]).permute(1, 0, 2).to(device).requires_grad_(True)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its labels are `b,c`
        targets = torch.tensor([2, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='mean')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo = k2.ctc_topo(4)
        linear_fsa = k2.linear_fsa([2, 3])
        decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

        k2_loss = k2.ctc_loss(decoding_graph,
                              dense_fsa_vec,
                              reduction='mean',
                              target_lengths=target_lengths)

        expected_loss = torch.tensor([4.938850402832],
                                     device=device) / target_lengths

        assert torch.allclose(torch_loss, k2_loss)
        assert torch.allclose(torch_loss, expected_loss)

        torch_loss.backward()
        k2_loss.backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)

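# A small sketch (not part of the original tests) making the normalization in
# test_case3 explicit: with reduction='mean', torch.nn.functional.ctc_loss
# divides each sequence's negative log-likelihood by its target length before
# averaging over the batch, which is why expected_loss above is
# 4.938850402832 / target_lengths.
def _check_case3_expected_loss():
    total_nll = 4.938850402832  # per-sequence total loss from test_case3
    target_length = 2           # the target is [b, c]
    assert abs(total_nll / target_length - 2.469425201416) < 1e-6
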
def test(self):
    for device in self.devices:
        fsa1 = k2.ctc_topo(5, device=device)
        fsa1.attr1 = torch.tensor([1] * fsa1.num_arcs, device=device)

        stream1 = k2.RnntDecodingStream(fsa1)

        fsa2 = k2.trivial_graph(3, device=device)
        fsa2.attr1 = torch.tensor([2] * fsa2.num_arcs, device=device)
        fsa2.attr2 = torch.tensor([22] * fsa2.num_arcs, device=device)

        stream2 = k2.RnntDecodingStream(fsa2)

        fsa3 = k2.ctc_topo(3, modified=True, device=device)
        fsa3.attr3 = k2.RaggedTensor(
            torch.ones((fsa3.num_arcs, 2), dtype=torch.int32,
                       device=device) * 3)

        stream3 = k2.RnntDecodingStream(fsa3)

        config = k2.RnntDecodingConfig(10, 2, 3.0, 3, 3)
        streams = k2.RnntDecodingStreams([stream1, stream2, stream3],
                                         config)

        for i in range(5):
            shape, context = streams.get_contexts()
            logprobs = torch.randn((context.shape[0], 10),
                                   dtype=torch.float32,
                                   device=device)
            streams.advance(logprobs)

        streams.terminate_and_flush_to_streams()
        ofsa = streams.format_output([3, 4, 5])
        print(ofsa)

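# A hedged annotation (an assumption about the k2 API, not taken from the
# original test): the five positional arguments passed to
# k2.RnntDecodingConfig above appear to be, in order,
#   vocab_size=10, decoder_history_len=2, beam=3.0, max_states=3, max_contexts=3,
# and the second dimension of the `logprobs` tensor fed to streams.advance()
# has to match vocab_size. If that reading is correct, an equivalent
# keyword-argument construction would look like:
#
#   config = k2.RnntDecodingConfig(vocab_size=10,
#                                  decoder_history_len=2,
#                                  beam=3.0,
#                                  max_states=3,
#                                  max_contexts=3)
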
def test_random_case1(self):
    # 1 sequence
    for device in self.devices:
        T = torch.randint(10, 100, (1, )).item()
        C = torch.randint(20, 30, (1, )).item()
        torch_activation = torch.rand((1, T + 10, C),
                                      dtype=torch.float32,
                                      device=device).requires_grad_(True)
        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # [N, T, C] -> [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation.permute(1, 0, 2), dim=-1)

        input_lengths = torch.tensor([T]).to(device)
        target_lengths = torch.randint(1, T, (1, )).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.item(), )).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='mean')

        k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                       dim=-1)
        supervision_segments = torch.tensor([[0, 0, T]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo = k2.ctc_topo(C - 1)
        linear_fsa = k2.linear_fsa([targets.tolist()])
        decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

        k2_loss = k2.ctc_loss(decoding_graph,
                              dense_fsa_vec,
                              reduction='mean',
                              target_lengths=target_lengths)

        assert torch.allclose(torch_loss, k2_loss)

        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (k2_loss * scale).sum().backward()
        assert torch.allclose(torch_activation.grad,
                              k2_activation.grad,
                              atol=1e-2)

def _visualize_ctc_topo():
    '''See https://git.io/JtqyJ
    for what the resulting ctc_topo looks like.
    '''
    symbols = k2.SymbolTable.from_str('''
        <blk> 0
        a 1
        b 2
    ''')
    aux_symbols = k2.SymbolTable.from_str('''
        a 1
        b 2
    ''')
    ctc_topo = k2.ctc_topo(2)
    ctc_topo.labels_sym = symbols
    ctc_topo.aux_labels_sym = aux_symbols
    ctc_topo.draw('ctc_topo.pdf')

def test_random_case2(self):
    # 2 sequences
    for device in self.devices:
        T1 = torch.randint(10, 200, (1, )).item()
        T2 = torch.randint(9, 100, (1, )).item()
        C = torch.randint(20, 30, (1, )).item()
        if T1 < T2:
            T1, T2 = T2, T1

        torch_activation_1 = torch.rand((T1, C),
                                        dtype=torch.float32,
                                        device=device).requires_grad_(True)
        torch_activation_2 = torch.rand((T2, C),
                                        dtype=torch.float32,
                                        device=device).requires_grad_(True)

        k2_activation_1 = torch_activation_1.detach().clone(
        ).requires_grad_(True)
        k2_activation_2 = torch_activation_2.detach().clone(
        ).requires_grad_(True)

        # [T, N, C]
        torch_activations = torch.nn.utils.rnn.pad_sequence(
            [torch_activation_1, torch_activation_2],
            batch_first=False,
            padding_value=0)

        # [N, T, C]
        k2_activations = torch.nn.utils.rnn.pad_sequence(
            [k2_activation_1, k2_activation_2],
            batch_first=True,
            padding_value=0)

        target_length1 = torch.randint(1, T1, (1, )).item()
        target_length2 = torch.randint(1, T2, (1, )).item()
        target_lengths = torch.tensor([target_length1,
                                       target_length2]).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.sum(), )).to(device)

        # [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activations, dim=-1)
        input_lengths = torch.tensor([T1, T2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='mean')

        assert T1 >= T2
        supervision_segments = torch.tensor([[0, 0, T1], [1, 0, T2]],
                                            dtype=torch.int32)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                       dim=-1)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo = k2.ctc_topo(C - 1)
        linear_fsa = k2.linear_fsa([
            targets[:target_length1].tolist(),
            targets[target_length1:].tolist()
        ])
        decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

        k2_loss = k2.ctc_loss(decoding_graph,
                              dense_fsa_vec,
                              reduction='mean',
                              target_lengths=target_lengths)

        assert torch.allclose(torch_loss, k2_loss)

        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (k2_loss * scale).sum().backward()
        assert torch.allclose(torch_activation_1.grad,
                              k2_activation_1.grad,
                              atol=1e-2)
        assert torch.allclose(torch_activation_2.grad,
                              k2_activation_2.grad,
                              atol=1e-2)

def test_case4(self):
    for device in self.devices:
        # put case3, case2 and case1 into a batch
        torch_activation_1 = torch.tensor(
            [[0., 0., 0., 0., 0.]]).to(device).requires_grad_(True)

        torch_activation_2 = torch.arange(1, 16).reshape(3, 5).to(
            torch.float32).to(device).requires_grad_(True)

        torch_activation_3 = torch.tensor([
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]).to(device).requires_grad_(True)

        k2_activation_1 = torch_activation_1.detach().clone(
        ).requires_grad_(True)
        k2_activation_2 = torch_activation_2.detach().clone(
        ).requires_grad_(True)
        k2_activation_3 = torch_activation_3.detach().clone(
        ).requires_grad_(True)

        # [T, N, C]
        torch_activations = torch.nn.utils.rnn.pad_sequence(
            [torch_activation_3, torch_activation_2, torch_activation_1],
            batch_first=False,
            padding_value=0)

        # [N, T, C]
        k2_activations = torch.nn.utils.rnn.pad_sequence(
            [k2_activation_3, k2_activation_2, k2_activation_1],
            batch_first=True,
            padding_value=0)

        # [[b, c], [c, c], [a]]
        targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
        input_lengths = torch.tensor([3, 3, 1]).to(device)
        target_lengths = torch.tensor([2, 2, 1]).to(device)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activations, dim=-1)  # (T, N, C)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='sum')

        expected_loss = torch.tensor(
            [4.938850402832, 7.355742931366, 1.6094379425049]).sum()

        assert torch.allclose(torch_loss, expected_loss.to(device))

        k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                       dim=-1)
        supervision_segments = torch.tensor(
            [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo = k2.ctc_topo(4)
        # [[b, c], [c, c], [a]]
        linear_fsa = k2.linear_fsa([[2, 3], [3, 3], [1]])
        decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

        k2_loss = k2.ctc_loss(decoding_graph,
                              dense_fsa_vec,
                              reduction='sum',
                              target_lengths=target_lengths)

        assert torch.allclose(torch_loss, k2_loss)

        scale = torch.tensor([1., -2, 3.5]).to(device)
        (torch_loss * scale).sum().backward()
        (k2_loss * scale).sum().backward()
        assert torch.allclose(torch_activation_1.grad,
                              k2_activation_1.grad)
        assert torch.allclose(torch_activation_2.grad,
                              k2_activation_2.grad)
        assert torch.allclose(torch_activation_3.grad,
                              k2_activation_3.grad)

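# A short sketch (not part of the original tests) tying test_case4's expected
# value back to the individual cases: with reduction='sum' there is no
# per-target-length normalization, so the batch loss is simply the sum of the
# per-sequence losses of case3, case2 and case1 listed in expected_loss above.
def _check_case4_expected_loss():
    case3, case2, case1 = 4.938850402832, 7.355742931366, 1.6094379425049
    assert abs((case3 + case2 + case1) - 13.904031276703) < 1e-5
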
def visualize_ctc_topo():
    '''This function shows how to visualize standard/modified CTC
    topologies. It's for demonstration only, not for testing.
    '''
    max_token = 2
    labels_sym = k2.SymbolTable.from_str('''
        <blk> 0
        z 1
        o 2
    ''')
    aux_labels_sym = k2.SymbolTable.from_str('''
        z 1
        o 2
    ''')
    word_sym = k2.SymbolTable.from_str('''
        zoo 1
    ''')
    standard = k2.ctc_topo(max_token, modified=False)
    modified = k2.ctc_topo(max_token, modified=True)

    standard.labels_sym = labels_sym
    standard.aux_labels_sym = aux_labels_sym

    modified.labels_sym = labels_sym
    modified.aux_labels_sym = aux_labels_sym

    standard.draw('standard_topo.svg', title='standard CTC topo')
    modified.draw('modified_topo.svg', title='modified CTC topo')

    fsa = k2.linear_fst([1, 2, 2], [1, 0, 0])
    fsa.labels_sym = labels_sym
    fsa.aux_labels_sym = word_sym
    fsa.draw('transcript.svg', title='transcript')

    standard_graph = k2.compose(standard, fsa)
    modified_graph = k2.compose(modified, fsa)
    standard_graph.draw('standard_graph.svg', title='standard graph')
    modified_graph.draw('modified_graph.svg', title='modified graph')

    # z z <blk> <blk> o o <blk> o <blk>
    inputs = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 2, 0])
    inputs.labels_sym = labels_sym
    inputs.draw('inputs.svg', title='inputs')

    standard_lattice = k2.intersect(standard_graph,
                                    inputs,
                                    treat_epsilons_specially=False)
    standard_lattice.draw('standard_lattice.svg', title='standard lattice')

    modified_lattice = k2.intersect(modified_graph,
                                    inputs,
                                    treat_epsilons_specially=False)
    modified_lattice = k2.connect(modified_lattice)
    modified_lattice.draw('modified_lattice.svg', title='modified lattice')

    # z z <blk> <blk> o o o <blk>
    inputs2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 2, 0])
    inputs2.labels_sym = labels_sym
    inputs2.draw('inputs2.svg', title='inputs2')

    standard_lattice2 = k2.intersect(standard_graph,
                                     inputs2,
                                     treat_epsilons_specially=False)
    standard_lattice2 = k2.connect(standard_lattice2)
    # It's empty since the standard topo requires that there must be a
    # blank between the two o's in zoo.
    assert standard_lattice2.num_arcs == 0
    standard_lattice2.draw('standard_lattice2.svg',
                           title='standard lattice2')

    modified_lattice2 = k2.intersect(modified_graph,
                                     inputs2,
                                     treat_epsilons_specially=False)
    modified_lattice2 = k2.connect(modified_lattice2)
    modified_lattice2.draw('modified_lattice2.svg',
                           title='modified lattice2')