def test_batch_encoding_with_labels_tf(self):
    """Check the shapes produced by ``convert_to_tensors(tensor_type="tf")``.

    Covers a batched encoding (two rows of three ids, one label per row) and
    a single un-batched example converted with ``prepend_batch_axis=True``.
    """
    # Batched case: nested lists for inputs, flat list for labels.
    encoding = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
    converted = encoding.convert_to_tensors(tensor_type="tf")
    for key, expected_shape in (("inputs", (2, 3)), ("labels", (2,))):
        self.assertEqual(converted[key].shape, expected_shape)

    # Single example: a leading axis of size 1 is expected on both entries
    # when prepend_batch_axis=True is requested.
    encoding = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
    converted = encoding.convert_to_tensors(tensor_type="tf", prepend_batch_axis=True)
    for key, expected_shape in (("inputs", (1, 3)), ("labels", (1,))):
        self.assertEqual(converted[key].shape, expected_shape)
def test_batch_encoding_with_labels_jax(self):
    """Check the shapes produced by ``convert_to_tensors(tensor_type="jax")``.

    Covers a batched encoding, a repeated conversion of that same encoding
    (which must not write anything to stderr), and a single un-batched
    example converted with ``prepend_batch_axis=True``.
    """
    # Batched case: nested lists for inputs, flat list for labels.
    encoding = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
    converted = encoding.convert_to_tensors(tensor_type="jax")
    for key, expected_shape in (("inputs", (2, 3)), ("labels", (2,))):
        self.assertEqual(converted[key].shape, expected_shape)

    # Run the conversion a second time on the same encoding and make sure
    # nothing is emitted on stderr while doing so.
    with CaptureStderr() as cs:
        converted = encoding.convert_to_tensors(tensor_type="jax")
    self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")

    # Single example: a leading axis of size 1 is expected on both entries
    # when prepend_batch_axis=True is requested.
    encoding = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
    converted = encoding.convert_to_tensors(tensor_type="jax", prepend_batch_axis=True)
    for key, expected_shape in (("inputs", (1, 3)), ("labels", (1,))):
        self.assertEqual(converted[key].shape, expected_shape)