def test_psl_rule_1(self):
    """Checks the value returned by `rule_1` on the canned test logits.

    Builds a `PSLModelMultiWoZ` with only 'rule_1' configured (weight 1.0),
    evaluates `rule_1` on `test_util.LOGITS`, and expects a loss of 1.4.
    """
    rule_weights = (1.0,)
    rule_names = ('rule_1',)
    psl_constraints = model.PSLModelMultiWoZ(
        rule_weights, rule_names, config=self.config)
    logits = test_util.LOGITS

    loss = psl_constraints.rule_1(logits=tf.constant(logits))

    # The expected value is a non-integral float, so compare with a
    # tolerance (the same one the sibling rule tests use) rather than
    # demanding bit-exact equality.
    self.assertNear(loss, 1.4, err=1e-6)
def test_compute_loss(self):
    """Checks `compute_loss` with rule_11 (weight 1.0) and rule_12 (weight 2.0).

    The loss computed from `test_util.LOGITS` and `test_util.FEATURES` is
    expected to be 0.9 within a tolerance of 1e-6.
    """
    weights = (1.0, 2.0)
    names = ('rule_11', 'rule_12')
    constraint_model = model.PSLModelMultiWoZ(
        weights, names, config=self.config)

    total_loss = constraint_model.compute_loss(
        logits=tf.constant(test_util.LOGITS),
        data=test_util.FEATURES,
    )

    self.assertNear(total_loss, 0.9, err=1e-6)
def test_psl_rule_12(self):
    """Checks the value returned by `rule_12` on the canned test inputs.

    Only 'rule_12' is configured (weight 1.0); the loss computed from
    `test_util.LOGITS` and `test_util.FEATURES` is expected to be 0.1
    within a tolerance of 1e-6.
    """
    weights = (1.0,)
    names = ('rule_12',)
    constraint_model = model.PSLModelMultiWoZ(
        weights, names, config=self.config)

    rule_loss = constraint_model.rule_12(
        logits=tf.constant(test_util.LOGITS),
        data=test_util.FEATURES,
    )

    self.assertNear(rule_loss, 0.1, err=1e-6)
def test_psl_rule_7_run_model(self):
    """Fits a constrained model and checks predictions evaluated with rule_7.

    After training on `self.train_ds`, the model is evaluated on
    `self.test_ds` with a 'rule_7'-only PSL model. The argmax predictions at
    positions [1][2] and [2][3] must equal the 'end' entry of the class map.
    """
    constraint_model = model.PSLModelMultiWoZ(
        (1.0,), ('rule_7',), config=self.config)

    input_size = [
        self.config['max_dialog_size'],
        self.config['max_utterance_size'],
    ]
    network = test_util.build_constrained_model(input_size)
    network.fit(self.train_ds, epochs=self.config['train_epochs'])

    eval_logits = eval_model.evaluate_constrained_model(
        network, self.test_ds, constraint_model)
    predicted = tf.math.argmax(eval_logits[0], axis=-1)

    # Positions checked by this test, in the original order.
    for row, col in ((1, 2), (2, 3)):
        self.assertEqual(predicted[row][col],
                         self.config['class_map']['end'])
def test_psl_rule_1_run_model(self):
    """Fits a constrained model and checks predictions evaluated with rule_1.

    After training on `self.train_ds`, the model is evaluated on
    `self.test_ds` with a 'rule_1'-only PSL model. The argmax predictions
    are passed to `self.check_greet` together with `self.test_labels[1]` and
    the class map; the check must return a truthy value.
    """
    constraint_model = model.PSLModelMultiWoZ(
        (1.0,), ('rule_1',), config=self.config)

    input_size = [
        self.config['max_dialog_size'],
        self.config['max_utterance_size'],
    ]
    network = test_util.build_constrained_model(input_size)
    network.fit(self.train_ds, epochs=self.config['train_epochs'])

    eval_logits = eval_model.evaluate_constrained_model(
        network, self.test_ds, constraint_model)
    predicted = tf.math.argmax(eval_logits[0], axis=-1)

    self.assertTrue(
        self.check_greet(predicted, self.test_labels[1],
                         self.config['class_map']))
def get_psl_model(dataset: str, rule_names: List[str],
                  rule_weights: List[float],
                  **kwargs) -> psl_model.PSLModel:
    """Constructs the PSL constraint model for `dataset`.

    Args:
      dataset: Dataset name; checked with `_check_dataset_supported` before
        the model is built.
      rule_names: Rule names forwarded to the model constructor.
      rule_weights: Rule weights forwarded to the model constructor.
      **kwargs: Additional keyword arguments forwarded unchanged to the
        model constructor.

    Returns:
      A `psl_model_multiwoz.PSLModelMultiWoZ` instance.
    """
    _check_dataset_supported(dataset)
    # Note the constructor takes weights first, then names.
    constraint_model = psl_model_multiwoz.PSLModelMultiWoZ(
        rule_weights, rule_names, **kwargs)
    return constraint_model
def test_loss_shape(self):
    """Checks the structure of the result of `linear_vrnn.compute_loss`.

    Symbolic Keras inputs stand in for labels, masks, and model outputs, and
    a 'rule_1'-only PSL model is supplied. The result must have 8 entries:
    the first seven have an empty (scalar) shape, and the last holds one
    scalar-shaped loss per configured rule.
    """
    config = self._create_test_vrnn_config()
    batch_size = 2

    def _symbolic_input(*trailing_dims, **input_kwargs):
        # Every input here shares the leading max_dialog_length dimension
        # and the fixed batch size; only the trailing dims/dtype vary.
        return tf.keras.Input(
            shape=(config.max_dialog_length,) + trailing_dims,
            batch_size=batch_size,
            **input_kwargs)

    labels_1 = _symbolic_input(None)
    labels_2 = _symbolic_input(None)
    labels_1_mask = _symbolic_input(None)
    labels_2_mask = _symbolic_input(None)
    model_outputs = [
        _symbolic_input(None),
        _symbolic_input(None),
        _symbolic_input(config.num_states),
        _symbolic_input(config.num_states),
        _symbolic_input(config.vae_cell.max_seq_length, None),
        _symbolic_input(config.vae_cell.max_seq_length, None),
        _symbolic_input(None),
        _symbolic_input(None),
    ]
    latent_label_id = _symbolic_input(dtype=tf.int32)
    latent_label_mask = _symbolic_input()
    word_weights = np.ones(config.vocab_size, dtype=np.float32)

    rule_names = ('rule_1',)
    psl_constraints = psl_model_multiwoz.PSLModelMultiWoZ(
        (1.0,), rule_names, config=psl_test_util.TEST_MULTIWOZ_CONFIG)
    psl_inputs = _symbolic_input(8)

    outputs = linear_vrnn.compute_loss(
        labels_1,
        labels_2,
        labels_1_mask,
        labels_2_mask,
        model_outputs,
        latent_label_id,
        latent_label_mask,
        word_weights,
        with_bpr=True,
        kl_loss_weight=0.5,
        with_bow=True,
        bow_loss_weight=0.3,
        num_latent_states=config.num_states,
        classification_loss_weight=0.8,
        psl_constraint_model=psl_constraints,
        psl_inputs=psl_inputs,
        psl_constraint_loss_weight=0.1)

    self.assertLen(outputs, 8)
    scalar_losses, loss_per_rule = outputs[:7], outputs[7]
    for scalar_loss in scalar_losses:
        self.assertEqual([], scalar_loss.shape.as_list())

    self.assertLen(loss_per_rule, len(rule_names))
    for rule_loss in loss_per_rule:
        self.assertEqual([], rule_loss.shape.as_list())