コード例 #1
0
    def test_psl_rule_1(self):
        """Checks the standalone rule_1 loss on the shared test logits.

        The loss is a floating-point tensor, so it is compared with a small
        tolerance instead of exact equality. Exact equality on floats is
        brittle under float32 rounding.
        """
        rule_weights = (1.0, )
        rule_names = ('rule_1', )
        psl_constraints = model.PSLModelMultiWoZ(rule_weights,
                                                 rule_names,
                                                 config=self.config)
        logits = test_util.LOGITS

        loss = psl_constraints.rule_1(logits=tf.constant(logits))
        # Tolerance matches the other PSL loss checks (err=1e-6).
        self.assertNear(loss, 1.4, err=1e-6)
コード例 #2
0
    def test_compute_loss(self):
        """Checks compute_loss with rule_11 (weight 1.0) and rule_12 (2.0)."""
        psl_constraints = model.PSLModelMultiWoZ((1.0, 2.0),
                                                 ('rule_11', 'rule_12'),
                                                 config=self.config)
        logits_tensor = tf.constant(test_util.LOGITS)

        total_loss = psl_constraints.compute_loss(logits=logits_tensor,
                                                  data=test_util.FEATURES)
        self.assertNear(total_loss, 0.9, err=1e-6)
コード例 #3
0
    def test_psl_rule_12(self):
        """Checks the standalone rule_12 loss on the shared logits/features."""
        weights, names = (1.0, ), ('rule_12', )
        constraint_model = model.PSLModelMultiWoZ(weights, names,
                                                  config=self.config)

        rule_loss = constraint_model.rule_12(
            logits=tf.constant(test_util.LOGITS), data=test_util.FEATURES)
        self.assertNear(rule_loss, 0.1, err=1e-6)
コード例 #4
0
    def test_psl_rule_7_run_model(self):
        """Trains the test model, evaluates it under rule_7, checks 'end'.

        After evaluation, two specific positions of the argmax predictions
        must equal the 'end' class id from the config's class map.
        """
        cfg = self.config
        psl_constraints = model.PSLModelMultiWoZ((1.0, ), ('rule_7', ),
                                                 config=cfg)

        input_shape = [cfg['max_dialog_size'], cfg['max_utterance_size']]
        constrained_model = test_util.build_constrained_model(input_shape)
        constrained_model.fit(self.train_ds, epochs=cfg['train_epochs'])

        logits = eval_model.evaluate_constrained_model(
            constrained_model, self.test_ds, psl_constraints)
        predictions = tf.math.argmax(logits[0], axis=-1)

        end_class = cfg['class_map']['end']
        for row, col in ((1, 2), (2, 3)):
            self.assertEqual(predictions[row][col], end_class)
コード例 #5
0
    def test_psl_rule_1_run_model(self):
        """Trains the test model, evaluates it under rule_1, checks greetings.

        The argmax predictions are validated against the second test label
        set via the ``check_greet`` helper.
        """
        cfg = self.config
        constraints = model.PSLModelMultiWoZ((1.0, ), ('rule_1', ),
                                             config=cfg)

        trained_model = test_util.build_constrained_model(
            [cfg['max_dialog_size'], cfg['max_utterance_size']])
        trained_model.fit(self.train_ds, epochs=cfg['train_epochs'])

        logits = eval_model.evaluate_constrained_model(trained_model,
                                                       self.test_ds,
                                                       constraints)
        predictions = tf.math.argmax(logits[0], axis=-1)
        self.assertTrue(
            self.check_greet(predictions, self.test_labels[1],
                             cfg['class_map']))
コード例 #6
0
def get_psl_model(dataset: str, rule_names: List[str],
                  rule_weights: List[float], **kwargs) -> psl_model.PSLModel:
  """Constructs the PSL constraint model for `dataset`.

  Args:
    dataset: Dataset name; checked by `_check_dataset_supported` first.
    rule_names: Names of the PSL rules to enable.
    rule_weights: Rule weights (presumably one per entry of `rule_names` --
      confirm against `PSLModelMultiWoZ`).
    **kwargs: Extra keyword arguments forwarded unchanged to
      `PSLModelMultiWoZ` (e.g. `config`).

  Returns:
    A `psl_model_multiwoz.PSLModelMultiWoZ` instance. It is returned for
    every dataset that passes the support check.

  Raises:
    Whatever `_check_dataset_supported` raises for an unsupported dataset
    (not visible here -- confirm).
  """
  _check_dataset_supported(dataset)
  # The constructor takes weights before names, which is the reverse of this
  # function's parameter order.
  return psl_model_multiwoz.PSLModelMultiWoZ(rule_weights, rule_names, **kwargs)
コード例 #7
0
    def test_loss_shape(self):
        """Checks that linear_vrnn.compute_loss produces scalar losses.

        Builds symbolic Keras inputs for every argument and runs compute_loss
        with a single-rule PSL constraint model. The result must have eight
        outputs: seven scalar losses followed by a per-rule loss collection
        (one entry per rule name) whose entries are scalars too.
        """
        config = self._create_test_vrnn_config()
        batch_size = 2
        dialog_len = config.max_dialog_length

        def make_input(shape, **kwargs):
            # Every symbolic input in this test shares the same batch size.
            return tf.keras.Input(shape=shape, batch_size=batch_size, **kwargs)

        # Two label tensors followed by their masks, built in argument order.
        labels_1, labels_2, labels_1_mask, labels_2_mask = [
            make_input((dialog_len, None)) for _ in range(4)
        ]

        seq_len = config.vae_cell.max_seq_length
        output_shapes = [
            (dialog_len, None),
            (dialog_len, None),
            (dialog_len, config.num_states),
            (dialog_len, config.num_states),
            (dialog_len, seq_len, None),
            (dialog_len, seq_len, None),
            (dialog_len, None),
            (dialog_len, None),
        ]
        model_outputs = [make_input(shape) for shape in output_shapes]

        latent_label_id = make_input((dialog_len, ), dtype=tf.int32)
        latent_label_mask = make_input((dialog_len, ))
        word_weights = np.ones(config.vocab_size, dtype=np.float32)

        rule_names = ('rule_1', )
        psl_constraints = psl_model_multiwoz.PSLModelMultiWoZ(
            (1.0, ), rule_names, config=psl_test_util.TEST_MULTIWOZ_CONFIG)
        psl_inputs = make_input((dialog_len, 8))

        outputs = linear_vrnn.compute_loss(
            labels_1, labels_2, labels_1_mask, labels_2_mask,
            model_outputs, latent_label_id, latent_label_mask, word_weights,
            with_bpr=True,
            kl_loss_weight=0.5,
            with_bow=True,
            bow_loss_weight=0.3,
            num_latent_states=config.num_states,
            classification_loss_weight=0.8,
            psl_constraint_model=psl_constraints,
            psl_inputs=psl_inputs,
            psl_constraint_loss_weight=0.1)

        self.assertLen(outputs, 8)
        scalar_losses, loss_per_rule = outputs[:7], outputs[7]
        for scalar_loss in scalar_losses:
            self.assertEqual([], scalar_loss.shape.as_list())

        self.assertLen(loss_per_rule, len(rule_names))
        for rule_loss in loss_per_rule:
            self.assertEqual([], rule_loss.shape.as_list())