Esempio n. 1
0
 def test_pruned_max(self):
     forward_max_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     backward_max_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                     k2host.FbWeightType.kMaxWeight,
                                     forward_max_weights,
                                     backward_max_weights)
     beam = 8.0
     remover = k2host.EpsilonsRemoverPrunedMax(wfsa, beam)
     fsa_size = k2host.IntArray2Size()
     arc_derivs_size = k2host.IntArray2Size()
     remover.get_sizes(fsa_size, arc_derivs_size)
     fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
     arc_derivs = k2host.IntArray2.create_array_with_size(arc_derivs_size)
     arc_weights_out = k2host.FloatArray1.create_array_with_size(
         fsa_size.size2)
     remover.get_output(fsa_out, arc_derivs)
     self.assertTrue(k2host.is_epsilon_free(fsa_out))
     self.assertEqual(fsa_out.size1, 6)
     self.assertEqual(fsa_out.size2, 11)  # TODO: fix this
     self.assertEqual(arc_derivs.size1, 11)  # TODO: fix this
     self.assertEqual(arc_derivs.size2, 18)  # TODO: fix this
     self.assertTrue(
         k2host.is_rand_equivalent_max_weight(self.fsa, fsa_out, beam))
Esempio n. 2
0
 def test_pruned_logsum(self):
     forward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     backward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                     k2host.FbWeightType.kLogSumWeight,
                                     forward_logsum_weights,
                                     backward_logsum_weights)
     beam = 10.0
     determinizer = k2host.DeterminizerPrunedLogSum(
         wfsa, beam, 100, k2host.FbWeightType.kNoWeight)
     fsa_size = k2host.IntArray2Size()
     arc_derivs_size = k2host.IntArray2Size()
     determinizer.get_sizes(fsa_size, arc_derivs_size)
     fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
     arc_derivs = k2host.LogSumArcDerivs.create_arc_derivs_with_size(
         arc_derivs_size)
     arc_weights_out = k2host.FloatArray1.create_array_with_size(
         fsa_size.size2)
     determinizer.get_output(fsa_out, arc_derivs)
     self.assertTrue(k2host.is_deterministic(fsa_out))
     self.assertEqual(fsa_out.size1, 7)
     self.assertEqual(fsa_out.size2, 9)
     self.assertEqual(arc_derivs.size1, 9)
     self.assertEqual(arc_derivs.size2, 15)
     self.assertTrue(
         k2host.is_rand_equivalent_logsum_weight(self.fsa, fsa_out, beam))
     # cast float to int
     arc_ids = k2host.StridedIntArray1.from_float_tensor(arc_derivs.data[:,
                                                                         0])
Esempio n. 3
0
 def test_pruned_max(self):
     forward_max_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     backward_max_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                     k2host.FbWeightType.kMaxWeight,
                                     forward_max_weights,
                                     backward_max_weights)
     beam = 10.0
     determinizer = k2host.DeterminizerPrunedMax(
         wfsa, beam, 100, k2host.FbWeightType.kNoWeight)
     fsa_size = k2host.IntArray2Size()
     arc_derivs_size = k2host.IntArray2Size()
     determinizer.get_sizes(fsa_size, arc_derivs_size)
     fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
     arc_derivs = k2host.IntArray2.create_array_with_size(arc_derivs_size)
     arc_weights_out = k2host.FloatArray1.create_array_with_size(
         fsa_size.size2)
     determinizer.get_output(fsa_out, arc_derivs)
     self.assertTrue(k2host.is_deterministic(fsa_out))
     self.assertEqual(fsa_out.size1, 7)
     self.assertEqual(fsa_out.size2, 9)
     self.assertEqual(arc_derivs.size1, 9)
     self.assertEqual(arc_derivs.size2, 12)
     self.assertTrue(
         k2host.is_rand_equivalent_max_weight(self.fsa, fsa_out, beam))
Esempio n. 4
0
 def test_logsum_weight(self):
     forward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     backward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                     k2host.FbWeightType.kLogSumWeight,
                                     forward_logsum_weights,
                                     backward_logsum_weights)
     expected_forward_logsum_weights = torch.DoubleTensor(
         [0, 1, 3, 4, 1,
          float('-inf'), 3, 9.126928, 4, 14.143222])
     expected_backward_logsum_weights = torch.DoubleTensor(
         [14.143222, 13.126928, 9, 10, 9.018150, 4, 3, 5, 6, 0])
     self.assertTrue(
         torch.allclose(forward_logsum_weights.data,
                        expected_forward_logsum_weights))
     self.assertTrue(
         torch.allclose(backward_logsum_weights.data,
                        expected_backward_logsum_weights))
Esempio n. 5
0
 def test_max_weight(self):
     forward_max_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     backward_max_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                     k2host.FbWeightType.kMaxWeight,
                                     forward_max_weights,
                                     backward_max_weights)
     expected_forward_max_weights = torch.DoubleTensor(
         [0, 1, 3, 4, 1, float('-inf'), 3, 9, 4, 14])
     expected_backward_max_weights = torch.DoubleTensor(
         [14, 13, 9, 10, 9, 4, 3, 5, 6, 0])
     self.assertTrue(
         torch.equal(forward_max_weights.data,
                     expected_forward_max_weights))
     self.assertTrue(
         torch.allclose(forward_max_weights.data,
                        expected_forward_max_weights))
     self.assertTrue(
         torch.allclose(backward_max_weights.data,
                        expected_backward_max_weights))
Esempio n. 6
0
 def test_pruned_logsum(self):
     forward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     backward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
         self.num_states)
     wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                     k2host.FbWeightType.kLogSumWeight,
                                     forward_logsum_weights,
                                     backward_logsum_weights)
     beam = 8.0
     remover = k2host.EpsilonsRemoverPrunedLogSum(wfsa, beam)
     fsa_size = k2host.IntArray2Size()
     arc_derivs_size = k2host.IntArray2Size()
     remover.get_sizes(fsa_size, arc_derivs_size)
     fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
     arc_derivs = k2host.LogSumArcDerivs.create_arc_derivs_with_size(
         arc_derivs_size)
     arc_weights_out = k2host.FloatArray1.create_array_with_size(
         fsa_size.size2)
     remover.get_output(fsa_out, arc_derivs)
     self.assertTrue(k2host.is_epsilon_free(fsa_out))
     self.assertEqual(fsa_out.size1, 6)
     self.assertEqual(fsa_out.size2, 11)  # TODO: fix this
     self.assertEqual(arc_derivs.size1, 11)  # TODO: fix this
     self.assertEqual(arc_derivs.size2, 20)  # TODO: fix this
     # TODO(haowen): uncomment this after re-implementing
     # IsRandEquivalentAfterRmEpsPrunedLogSum
     #self.assertTrue(
     #    k2host.is_rand_equivalent_after_rmeps_pruned_logsum(
     #        self.fsa, fsa_out, beam))
     # cast float to int
     arc_ids = k2host.StridedIntArray1.from_float_tensor(arc_derivs.data[:,
                                                                         0])
     # we may get different value of `arc_ids.get_data(1)`
     # with different STL implementations as we use
     # `std::unordered_map` in implementation of rmepsilon,
     # thus below assertion may fail on some platforms.
     self.assertEqual(arc_ids.get_data(1), 1)