def test_tf_bleurt_positional_args_error(self):
    """Check that invoking the BLEURT ops with positional arguments fails.

    The ops callable returned by `score.create_bleurt_ops()` is called with
    two positional tensors; the test passes only if that call raises
    `AssertionError`.
    """
    # Build the ops callable first, then wrap the shared fixtures as tensors.
    ops_fn = score.create_bleurt_ops()
    candidate_tensor = tf.constant(candidates)
    reference_tensor = tf.constant(references)

    # The positional call itself is what must be rejected.
    with self.assertRaises(AssertionError):
        ops_fn(reference_tensor, candidate_tensor)
def test_tf_bleurt_score_eager(self):
    """Score the shared sentence pairs with the BLEURT ops and check the output.

    Builds the ops with `score.create_bleurt_ops()`, feeds it the
    `references` / `candidates` fixtures (defined outside this method) as
    TF constants, and verifies that the result exposes a "predictions"
    entry of shape `(2,)` whose values are close to `ref_scores`.
    """
    # Creates the TF Graph.
    bleurt_ops = score.create_bleurt_ops()
    tfcandidates = tf.constant(candidates)
    tfreferences = tf.constant(references)

    # The ops callable rejects positional arguments (it raises
    # AssertionError), so the inputs must be passed by keyword. Naming them
    # also removes any risk of swapping references and candidates.
    bleurt_out = bleurt_ops(references=tfreferences, candidates=tfcandidates)

    # Computes the BLEURT scores.
    self.assertIn("predictions", bleurt_out)
    self.assertEqual(bleurt_out["predictions"].shape, (2,))
    self.assertAllClose(bleurt_out["predictions"], ref_scores)
def test_tf_bleurt_score_not_eager(self):
    """Score the shared sentence pairs through an explicit session.

    Runs inside a session bound to a fresh `tf.Graph()`: builds the BLEURT
    ops, runs the variable and table initializers, evaluates the ops with
    `session.run`, and verifies that the result exposes a "predictions"
    entry of shape `(2,)` whose values are close to `ref_scores`.
    """
    with self.session(graph=tf.Graph()) as session:
        # Creates the TF Graph.
        bleurt_ops = score.create_bleurt_ops()

        # The ops callable rejects positional arguments (it raises
        # AssertionError), so the inputs must be passed by keyword.
        bleurt_scores = bleurt_ops(references=references, candidates=candidates)

        # Runs init.
        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.tables_initializer(),
        )
        session.run(init_op)

        # Computes the BLEURT scores.
        bleurt_out = session.run(bleurt_scores)

    self.assertIn("predictions", bleurt_out)
    self.assertEqual(bleurt_out["predictions"].shape, (2,))
    self.assertAllClose(bleurt_out["predictions"], ref_scores)
def __init__(self, **kwargs):
    """Initialize the layer and attach the BLEURT ops callable.

    Args:
        **kwargs: Forwarded unchanged to the parent class constructor.
    """
    # Parent initialization runs before any attribute is set on the layer.
    super(BleurtLayer, self).__init__(**kwargs)

    # `bleurt_model` is resolved from the enclosing scope, not from kwargs.
    ops_fn = bleurt_score.create_bleurt_ops(bleurt_model)
    self.bleurt_ops = ops_fn