Exemple #1
0
 def test_tf_bleurt_positional_args_error(self):
     """Checks that a positional call to the BLEURT ops raises AssertionError."""
     # Builds the BLEURT scoring callable and wraps the inputs as constants.
     ops = score.create_bleurt_ops()
     candidate_tensor = tf.constant(candidates)
     reference_tensor = tf.constant(references)

     # Both tensors are deliberately passed by position; the result is unused.
     with self.assertRaises(AssertionError):
         ops(reference_tensor, candidate_tensor)
  def test_tf_bleurt_score_eager(self):
    """Checks the BLEURT ops output when called directly on constant tensors.

    The inputs are passed by keyword because
    test_tf_bleurt_positional_args_error expects a positional call to raise
    AssertionError; a positional call here could not pass alongside it.
    """
    # Builds the BLEURT scoring callable and wraps the inputs as constants.
    bleurt_ops = score.create_bleurt_ops()
    tfcandidates = tf.constant(candidates)
    tfreferences = tf.constant(references)

    # Computes the BLEURT scores.
    # NOTE(review): keyword names `references` / `candidates` are assumed from
    # the BLEURT TF API -- confirm against score.create_bleurt_ops.
    bleurt_out = bleurt_ops(references=tfreferences, candidates=tfcandidates)

    # Checks the output key, the 1-D two-entry shape, and the score values.
    self.assertIn("predictions", bleurt_out)
    self.assertEqual(bleurt_out["predictions"].shape, (2,))
    self.assertAllClose(bleurt_out["predictions"], ref_scores)
Exemple #3
0
    def test_tf_bleurt_score_not_eager(self):
        """Checks the BLEURT ops output when run through an explicit session.

        The inputs are passed by keyword; this form works whether or not the
        callable returned by create_bleurt_ops accepts positional arguments.
        """
        with self.session(graph=tf.Graph()) as session:
            # Builds the scoring ops inside the fresh graph.
            # NOTE(review): keyword names `references` / `candidates` are
            # assumed from the BLEURT TF API -- confirm against
            # score.create_bleurt_ops.
            bleurt_ops = score.create_bleurt_ops()
            bleurt_scores = bleurt_ops(references=references,
                                       candidates=candidates)

            # Runs init for both variables and lookup tables.
            # NOTE(review): these are TF1-style names; this presumably relies
            # on `tf` being bound to a TF1-compatible namespace -- confirm.
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.tables_initializer())
            session.run(init_op)

            # Computes the BLEURT scores.
            bleurt_out = session.run(bleurt_scores)

        # Checks the output key, the 1-D two-entry shape, and the score values.
        self.assertIn("predictions", bleurt_out)
        self.assertEqual(bleurt_out["predictions"].shape, (2,))
        self.assertAllClose(bleurt_out["predictions"], ref_scores)
Exemple #4
0
 def __init__(self, **kwargs):
     """Initializes the layer and attaches the BLEURT scoring ops.

     Args:
         **kwargs: Forwarded unchanged to the parent class constructor.
     """
     super(BleurtLayer, self).__init__(**kwargs)
     # `bleurt_model` is not a parameter; it is resolved from the enclosing
     # scope at construction time.
     ops = bleurt_score.create_bleurt_ops(bleurt_model)
     self.bleurt_ops = ops