def test_prune(self):
    """Check that prune() drops a very low-probability item.

    Three items are added with probabilities 0.05, 0.9 and 0.00005.  The
    test expects the network to hold all three before pruning and only
    the first two afterwards.
    """
    dacn = DialogueActConfusionNetwork()
    dacn.add(0.05, DialogueActItem(dai='inform(food=chinese)'))
    dacn.add(0.9, DialogueActItem(dai='inform(food=czech)'))
    dacn.add(0.00005, DialogueActItem(dai='inform(food=russian)'))

    # Russian food should be pruned.
    self.assertEqual(len(dacn), 3)
    dacn.prune()
    self.assertEqual(len(dacn), 2)
    # assertNotIn checks the same condition as assertTrue(not x in y) but
    # names the item and the container in the failure message.
    self.assertNotIn(DialogueActItem(dai='inform(food=russian)'), dacn)
def test_prune(self):
    """Verify that pruning removes the least probable dialogue act item.

    The network is filled with three items (probabilities 0.05, 0.9 and
    0.00005); after prune() the test expects two items to remain and the
    0.00005 item to be absent.
    """
    dacn = DialogueActConfusionNetwork()
    dacn.add(0.05, DialogueActItem(dai='inform(food=chinese)'))
    dacn.add(0.9, DialogueActItem(dai='inform(food=czech)'))
    dacn.add(0.00005, DialogueActItem(dai='inform(food=russian)'))

    # Russian food should be pruned.
    self.assertEqual(len(dacn), 3)
    dacn.prune()
    self.assertEqual(len(dacn), 2)

    # Membership assertion instead of assertTrue(not ... in ...): same
    # pass/fail condition, but a failure reports what was found where.
    pruned_item = DialogueActItem(dai='inform(food=russian)')
    self.assertNotIn(pruned_item, dacn)
def parse_nblist(self, obs, *args, **kwargs):
    """
    Parses an observation that carries an utterance n-best list by running
    the parse_1_best method on each hypothesis and merging the results.

    An empty n-best list yields an empty DialogueActConfusionNetwork.  The
    hypotheses "_other_" and "_silence_" are not parsed; each is turned
    directly into a confusion network holding the single item "other" or
    "silence", respectively, with probability 1.0.  The per-hypothesis
    networks are combined with `merge_slu_confnets', then pruned and sorted.

    Arguments:
        obs -- a dictionary of observations
            :: observation type -> observed value
            where observation type is one of values for `obs_type' used in
            `ft_props', and observed value is the corresponding observed
            value for the input.  The n-best list is read from the key
            'utt_nbl'; `obs' itself is not modified.
        args -- further positional arguments that should be passed to the
            `parse_1_best' method call
        kwargs -- further keyword arguments that should be passed to the
            `parse_1_best' method call

    Returns the merged DialogueActConfusionNetwork.

    """
    hypotheses = obs['utt_nbl']
    if len(hypotheses) == 0:
        return DialogueActConfusionNetwork()

    # Work on a private copy without the n-best list; its 'utt' entry is
    # overwritten with each hypothesis in turn before calling parse_1_best.
    single_obs = copy.deepcopy(obs)
    del single_obs['utt_nbl']

    weighted_confnets = []
    for prob, utt in hypotheses:
        # Decide whether the hypothesis is one of the special tokens.
        if "_other_" == utt:
            special_dai = "other"
        elif "_silence_" == utt:
            special_dai = "silence"
        else:
            special_dai = None

        if special_dai is None:
            single_obs['utt'] = utt
            confnet = self.parse_1_best(single_obs, *args, **kwargs)
        else:
            confnet = DialogueActConfusionNetwork()
            confnet.add(1.0, DialogueActItem(special_dai))

        weighted_confnets.append((prob, confnet))

    merged = merge_slu_confnets(weighted_confnets)
    merged.prune()
    merged.sort()
    return merged