def test_pruned_logsum(self):
    """Check pruned log-sum determinization of ``self.fsa``.

    Wraps ``self.fsa`` in a ``WfsaWithFbWeights`` using log-sum
    forward/backward weights, runs ``DeterminizerPrunedLogSum`` with
    ``beam = 10.0``, and checks that:

    * the output FSA is deterministic,
    * the output FSA has ``size1 == 7`` and ``size2 == 9``
      (presumably states / arcs -- TODO confirm),
    * the arc-derivative array has ``size1 == 9`` and ``size2 == 15``,
    * the output is randomly equivalent to the input under log-sum
      weights, using ``beam``.
    """
    forward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    backward_logsum_weights = k2host.DoubleArray1.create_array_with_size(
        self.num_states)
    wfsa = k2host.WfsaWithFbWeights(self.fsa,
                                    k2host.FbWeightType.kLogSumWeight,
                                    forward_logsum_weights,
                                    backward_logsum_weights)
    beam = 10.0
    # NOTE(review): 100 is passed positionally; presumably a step/work
    # limit for the determinizer -- confirm against the binding signature.
    determinizer = k2host.DeterminizerPrunedLogSum(
        wfsa, beam, 100, k2host.FbWeightType.kNoWeight)

    # Query the output sizes first, allocate buffers of those sizes,
    # then let the determinizer fill them.
    fsa_size = k2host.IntArray2Size()
    arc_derivs_size = k2host.IntArray2Size()
    determinizer.get_sizes(fsa_size, arc_derivs_size)
    fsa_out = k2host.Fsa.create_fsa_with_size(fsa_size)
    arc_derivs = k2host.LogSumArcDerivs.create_arc_derivs_with_size(
        arc_derivs_size)
    determinizer.get_output(fsa_out, arc_derivs)

    self.assertTrue(k2host.is_deterministic(fsa_out))
    self.assertEqual(fsa_out.size1, 7)
    self.assertEqual(fsa_out.size2, 9)
    self.assertEqual(arc_derivs.size1, 9)
    self.assertEqual(arc_derivs.size2, 15)
    self.assertTrue(
        k2host.is_rand_equivalent_logsum_weight(self.fsa, fsa_out, beam))

    # cast float to int
    arc_ids = k2host.StridedIntArray1.from_float_tensor(arc_derivs.data[:, 0])
    # NOTE(review): ``arc_ids`` is never checked, so this conversion only
    # acts as a smoke test of ``from_float_tensor``; an assertion on its
    # contents appears to be missing -- confirm the intended check.
def test_logsum_weight(self):
    """Check random equivalence under log-sum weights.

    ``self.fsa_a`` must be reported as rand-equivalent to ``self.fsa_b``
    and as not rand-equivalent to ``self.fsa_c``.
    """
    a_matches_b = k2host.is_rand_equivalent_logsum_weight(
        self.fsa_a, self.fsa_b)
    self.assertTrue(a_matches_b)

    # Evaluated only after the first assertion passes, as in a chained check.
    a_matches_c = k2host.is_rand_equivalent_logsum_weight(
        self.fsa_a, self.fsa_c)
    self.assertFalse(a_matches_c)