def test_psl_rule_7_run_model(self):
    """Evaluate a trained model under rule_7 alone and check two 'end' predictions.

    A ``model.PSLModelMultiWoZ`` is built with exactly one active rule
    ('rule_7', weight 1.0). The constrained model is fit on
    ``self.train_ds`` without the PSL object. The constraints are only handed
    to ``eval_model.evaluate_constrained_model`` together with
    ``self.test_ds``. The argmax predictions at two fixed positions must
    equal the 'end' class id from ``self.config['class_map']``.
    """
    cfg = self.config

    # Single-rule configuration: only rule_7 participates, at full weight.
    weights = (1.0, )
    active_rules = ('rule_7', )
    psl_constraints = model.PSLModelMultiWoZ(
        weights, active_rules, config=cfg)

    input_shape = [cfg['max_dialog_size'], cfg['max_utterance_size']]
    constrained_model = test_util.build_constrained_model(input_shape)
    constrained_model.fit(self.train_ds, epochs=cfg['train_epochs'])

    logits = eval_model.evaluate_constrained_model(
        constrained_model, self.test_ds, psl_constraints)
    # NOTE(review): only logits[0] is inspected; argmax over its last axis.
    # The axis layout is not visible here -- presumably (dialog, turn, class),
    # TODO confirm against evaluate_constrained_model.
    predicted = tf.math.argmax(logits[0], axis=-1)

    end_id = cfg['class_map']['end']
    self.assertEqual(predicted[1][2], end_id)
    self.assertEqual(predicted[2][3], end_id)
def test_psl_rule_1_run_model(self):
    """Evaluate a trained model under rule_1 alone and verify it via check_greet.

    A ``model.PSLModelMultiWoZ`` is built with exactly one active rule
    ('rule_1', weight 1.0). The constrained model is fit on
    ``self.train_ds`` without the PSL object. The constraints are only handed
    to ``eval_model.evaluate_constrained_model`` together with
    ``self.test_ds``. The argmax predictions are then passed to
    ``self.check_greet`` along with ``self.test_labels[1]`` and the config's
    class map, and the test requires a truthy result.
    """
    cfg = self.config

    # Single-rule configuration: only rule_1 participates, at full weight.
    weights = (1.0, )
    active_rules = ('rule_1', )
    psl_constraints = model.PSLModelMultiWoZ(
        weights, active_rules, config=cfg)

    input_shape = [cfg['max_dialog_size'], cfg['max_utterance_size']]
    constrained_model = test_util.build_constrained_model(input_shape)
    constrained_model.fit(self.train_ds, epochs=cfg['train_epochs'])

    logits = eval_model.evaluate_constrained_model(
        constrained_model, self.test_ds, psl_constraints)
    # NOTE(review): only logits[0] is inspected; argmax over its last axis.
    # The axis layout is not visible here -- TODO confirm.
    predicted = tf.math.argmax(logits[0], axis=-1)

    # check_greet is defined elsewhere on the test class; its exact criterion
    # is not visible from this method.
    greet_ok = self.check_greet(
        predicted, self.test_labels[1], cfg['class_map'])
    self.assertTrue(greet_ok)