Esempio n. 1
0
    def test_batch_encoding_with_labels_tf(self):
        """Check ``convert_to_tensors(tensor_type="tf")`` on a BatchEncoding.

        Covers three cases:

        * a two-example batch (3 inputs per example, one label each) yields
          ``inputs`` of shape ``(2, 3)`` and ``labels`` of shape ``(2,)``;
        * running the conversion again on the same encoding writes nothing
          to stderr (the same "converting the converted" check the JAX test
          performs);
        * a single unbatched example converted with
          ``prepend_batch_axis=True`` gains a leading axis of size 1.
        """
        batch = BatchEncoding({
            "inputs": [[1, 2, 3], [4, 5, 6]],
            "labels": [0, 1]
        })
        tensor_batch = batch.convert_to_tensors(tensor_type="tf")
        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
        self.assertEqual(tensor_batch["labels"].shape, (2, ))
        # Test converting the converted. The call is made on `batch` again,
        # presumably because the first call already left TF tensors in it
        # (the JAX test relies on the same behaviour). Stderr must stay empty.
        with CaptureStderr() as cs:
            tensor_batch = batch.convert_to_tensors(tensor_type="tf")
        self.assertFalse(len(cs.err),
                         msg=f"should have no warning, but got {cs.err}")

        # An unbatched example needs an explicit leading batch axis.
        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
        tensor_batch = batch.convert_to_tensors(tensor_type="tf",
                                                prepend_batch_axis=True)
        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
        self.assertEqual(tensor_batch["labels"].shape, (1, ))

    def test_batch_encoding_with_labels_jax(self):
        """Check ``convert_to_tensors(tensor_type="jax")`` on a BatchEncoding.

        A two-example batch must produce ``inputs`` of shape ``(2, 3)`` and
        ``labels`` of shape ``(2,)``. Running the conversion again on the
        same encoding must write nothing to stderr. A single unbatched
        example converted with ``prepend_batch_axis=True`` must gain a
        leading axis of size 1.
        """
        pair = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]],
                              "labels": [0, 1]})
        converted = pair.convert_to_tensors(tensor_type="jax")
        self.assertEqual(converted["inputs"].shape, (2, 3))
        self.assertEqual(converted["labels"].shape, (2, ))

        # Second pass over the same encoding ("converting the converted"):
        # nothing may be written to stderr.
        with CaptureStderr() as cs:
            converted = pair.convert_to_tensors(tensor_type="jax")
        self.assertFalse(len(cs.err),
                         msg=f"should have no warning, but got {cs.err}")

        # Unbatched example: the conversion must add a batch axis of size 1.
        single = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
        converted = single.convert_to_tensors(tensor_type="jax",
                                              prepend_batch_axis=True)
        self.assertEqual(converted["inputs"].shape, (1, 3))
        self.assertEqual(converted["labels"].shape, (1, ))