def test_pruned_max(self):
    """Pruned epsilon removal under max weight: the output FSA must be
    epsilon-free and randomly equivalent to the input within `beam`."""
    forward_max_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    backward_max_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    # Constructing WfsaWithFbWeights fills the forward/backward arrays
    # in place; the object is then consumed by the remover below.
    wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                    k2host.FbWeightType.kMaxWeight,
                                    forward_max_weights,
                                    backward_max_weights)
    beam = 8.0
    remover = k2host.EpsilonsRemoverPrunedMax(wfsa, beam)
    # Two-phase API: query output sizes first, then allocate and fill.
    fsa_size = k2host.IntArray2Size()
    arc_derivs_size = k2host.IntArray2Size()
    remover.get_sizes(fsa_size, arc_derivs_size)
    fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
    arc_derivs = k2host.IntArray2.create_array_with_size(arc_derivs_size)
    remover.get_output(fsa_out, arc_derivs)
    self.assertTrue(k2host.is_epsilon_free(fsa_out))
    self.assertEqual(fsa_out.size1, 6)
    self.assertEqual(fsa_out.size2, 11)  # TODO: fix this
    self.assertEqual(arc_derivs.size1, 11)  # TODO: fix this
    self.assertEqual(arc_derivs.size2, 18)  # TODO: fix this
    self.assertTrue(
        k2host.is_rand_equivalent_max_weight(self.fsa, fsa_out, beam))
def test_pruned_logsum(self):
    """Pruned determinization under log-sum weight: output must be
    deterministic and randomly equivalent to the input within `beam`."""
    forward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    backward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    # Constructing WfsaWithFbWeights fills the forward/backward arrays
    # in place; the object is then consumed by the determinizer below.
    wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                    k2host.FbWeightType.kLogSumWeight,
                                    forward_logsum_weights,
                                    backward_logsum_weights)
    beam = 10.0
    determinizer = k2host.DeterminizerPrunedLogSum(
        wfsa, beam, 100, k2host.FbWeightType.kNoWeight)
    # Two-phase API: query output sizes first, then allocate and fill.
    fsa_size = k2host.IntArray2Size()
    arc_derivs_size = k2host.IntArray2Size()
    determinizer.get_sizes(fsa_size, arc_derivs_size)
    fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
    arc_derivs = k2host.LogSumArcDerivs.create_arc_derivs_with_size(
        arc_derivs_size)
    determinizer.get_output(fsa_out, arc_derivs)
    self.assertTrue(k2host.is_deterministic(fsa_out))
    self.assertEqual(fsa_out.size1, 7)
    self.assertEqual(fsa_out.size2, 9)
    self.assertEqual(arc_derivs.size1, 9)
    self.assertEqual(arc_derivs.size2, 15)
    self.assertTrue(
        k2host.is_rand_equivalent_logsum_weight(self.fsa, fsa_out, beam))
    # cast float to int: smoke-check that the first column of the derivs
    # reinterprets cleanly as arc ids. No value assertion here —
    # presumably the ids are implementation-dependent; confirm before
    # asserting concrete values.
    arc_ids = k2host.StridedIntArray1.from_float_tensor(arc_derivs.data[:, 0])
def test_pruned_max(self):
    """Pruned determinization under max weight: output must be
    deterministic and randomly equivalent to the input within `beam`."""
    forward_max_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    backward_max_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    # Constructing WfsaWithFbWeights fills the forward/backward arrays
    # in place; the object is then consumed by the determinizer below.
    wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                    k2host.FbWeightType.kMaxWeight,
                                    forward_max_weights,
                                    backward_max_weights)
    beam = 10.0
    determinizer = k2host.DeterminizerPrunedMax(
        wfsa, beam, 100, k2host.FbWeightType.kNoWeight)
    # Two-phase API: query output sizes first, then allocate and fill.
    fsa_size = k2host.IntArray2Size()
    arc_derivs_size = k2host.IntArray2Size()
    determinizer.get_sizes(fsa_size, arc_derivs_size)
    fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
    arc_derivs = k2host.IntArray2.create_array_with_size(arc_derivs_size)
    determinizer.get_output(fsa_out, arc_derivs)
    self.assertTrue(k2host.is_deterministic(fsa_out))
    self.assertEqual(fsa_out.size1, 7)
    self.assertEqual(fsa_out.size2, 9)
    self.assertEqual(arc_derivs.size1, 9)
    self.assertEqual(arc_derivs.size2, 12)
    self.assertTrue(
        k2host.is_rand_equivalent_max_weight(self.fsa, fsa_out, beam))
def test_logsum_weight(self):
    """Forward/backward scores under the log-sum semiring match the
    hand-computed reference values for this fixture FSA."""
    fwd = k2host.DoubleArray1.create_array_with_size(self.num_states)
    bwd = k2host.DoubleArray1.create_array_with_size(self.num_states)
    # Construction populates `fwd` and `bwd` in place.
    wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                    k2host.FbWeightType.kLogSumWeight,
                                    fwd, bwd)
    expected_fwd = torch.DoubleTensor(
        [0, 1, 3, 4, 1, float('-inf'), 3, 9.126928, 4, 14.143222])
    expected_bwd = torch.DoubleTensor(
        [14.143222, 13.126928, 9, 10, 9.018150, 4, 3, 5, 6, 0])
    self.assertTrue(torch.allclose(fwd.data, expected_fwd))
    self.assertTrue(torch.allclose(bwd.data, expected_bwd))
def test_max_weight(self):
    """Forward/backward scores under the max (tropical) semiring match
    the hand-computed reference values for this fixture FSA."""
    forward_max_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    backward_max_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    # Construction populates the forward/backward arrays in place.
    wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                    k2host.FbWeightType.kMaxWeight,
                                    forward_max_weights,
                                    backward_max_weights)
    expected_forward_max_weights = torch.DoubleTensor(
        [0, 1, 3, 4, 1, float('-inf'), 3, 9, 4, 14])
    expected_backward_max_weights = torch.DoubleTensor(
        [14, 13, 9, 10, 9, 4, 3, 5, 6, 0])
    # Max-plus over exact inputs is exact, so the strict torch.equal
    # check applies. (A redundant, strictly weaker torch.allclose on the
    # same forward pair was removed.)
    self.assertTrue(
        torch.equal(forward_max_weights.data,
                    expected_forward_max_weights))
    self.assertTrue(
        torch.allclose(backward_max_weights.data,
                       expected_backward_max_weights))
def test_pruned_logsum(self):
    """Pruned epsilon removal under log-sum weight: the output FSA must
    be epsilon-free; sizes and one deriv arc id are spot-checked."""
    forward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    backward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    # Constructing WfsaWithFbWeights fills the forward/backward arrays
    # in place; the object is then consumed by the remover below.
    wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                    k2host.FbWeightType.kLogSumWeight,
                                    forward_logsum_weights,
                                    backward_logsum_weights)
    beam = 8.0
    remover = k2host.EpsilonsRemoverPrunedLogSum(wfsa, beam)
    # Two-phase API: query output sizes first, then allocate and fill.
    fsa_size = k2host.IntArray2Size()
    arc_derivs_size = k2host.IntArray2Size()
    remover.get_sizes(fsa_size, arc_derivs_size)
    fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
    arc_derivs = k2host.LogSumArcDerivs.create_arc_derivs_with_size(
        arc_derivs_size)
    remover.get_output(fsa_out, arc_derivs)
    self.assertTrue(k2host.is_epsilon_free(fsa_out))
    self.assertEqual(fsa_out.size1, 6)
    self.assertEqual(fsa_out.size2, 11)  # TODO: fix this
    self.assertEqual(arc_derivs.size1, 11)  # TODO: fix this
    self.assertEqual(arc_derivs.size2, 20)  # TODO: fix this
    # TODO(haowen): uncomment this after re-implementing
    # IsRandEquivalentAfterRmEpsPrunedLogSum
    #self.assertTrue(
    #    k2host.is_rand_equivalent_after_rmeps_pruned_logsum(
    #        self.fsa, fsa_out, beam))
    # cast float to int
    arc_ids = k2host.StridedIntArray1.from_float_tensor(arc_derivs.data[:, 0])
    # we may get different value of `arc_ids.get_data(1)`
    # with different STL implementations as we use
    # `std::unordered_map` in implementation of rmepsilon,
    # thus below assertion may fail on some platforms.
    self.assertEqual(arc_ids.get_data(1), 1)