示例#1
0
    def test_psl_rule_7_run_model(self):
        """Checks that a model evaluated under PSL rule_7 predicts 'end'.

        Builds a constrained model for inputs of shape
        [max_dialog_size, max_utterance_size], trains it on `self.train_ds`,
        then evaluates it on `self.test_ds`. The PSL model used in evaluation
        has `rule_7` as its only rule, with weight 1.0. The test asserts that
        the argmax predictions at positions [1][2] and [2][3] equal
        `self.config['class_map']['end']`.
        """
        cfg = self.config
        constraints = model.PSLModelMultiWoZ((1.0, ), ('rule_7', ),
                                             config=cfg)

        input_shape = [cfg['max_dialog_size'], cfg['max_utterance_size']]
        trained = test_util.build_constrained_model(input_shape)
        trained.fit(self.train_ds, epochs=cfg['train_epochs'])

        outputs = eval_model.evaluate_constrained_model(
            trained, self.test_ds, constraints)
        # Pick the highest-scoring class along the last axis of the first
        # returned element.
        # NOTE(review): the two index levels below are presumably
        # (dialog, utterance) — confirm against the dataset layout.
        predicted = tf.math.argmax(outputs[0], axis=-1)

        end_class = cfg['class_map']['end']
        for row, col in ((1, 2), (2, 3)):
            self.assertEqual(predicted[row][col], end_class)
示例#2
0
    def test_psl_rule_1_run_model(self):
        """Checks predictions of a model evaluated under PSL rule_1.

        Builds a constrained model for inputs of shape
        [max_dialog_size, max_utterance_size], trains it on `self.train_ds`,
        then evaluates it on `self.test_ds`. The PSL model used in evaluation
        has `rule_1` as its only rule, with weight 1.0. The argmax predictions
        are then passed to `self.check_greet` together with
        `self.test_labels[1]` and the class map, and its result must be truthy.
        """
        cfg = self.config
        constraints = model.PSLModelMultiWoZ((1.0, ), ('rule_1', ),
                                             config=cfg)

        input_shape = [cfg['max_dialog_size'], cfg['max_utterance_size']]
        trained = test_util.build_constrained_model(input_shape)
        trained.fit(self.train_ds, epochs=cfg['train_epochs'])

        outputs = eval_model.evaluate_constrained_model(
            trained, self.test_ds, constraints)
        # Pick the highest-scoring class along the last axis of the first
        # returned element.
        predicted = tf.math.argmax(outputs[0], axis=-1)

        # NOTE(review): `check_greet` is defined elsewhere on this test class;
        # presumably it verifies greeting-related predictions against the
        # labels — confirm.
        self.assertTrue(
            self.check_greet(predicted, self.test_labels[1],
                             cfg['class_map']))