Ejemplo n.º 1
0
    def test_pooling_batcher(self):
        """PoolingBatcher sorts within a pool of batches and never mixes pools.

        Ten rows with ``a = 0..9`` are batched three at a time with two batches
        per pool and ``a`` as the sort key. The first two batches must only
        contain ``a < 6``, with a first key greater than the last. The
        remaining batches must only contain ``a >= 6``.
        """
        rows = []
        for idx in range(10):
            rows.append(
                RowData({"text": "something"}, {
                    "a": idx,
                    "b": 10 + idx,
                    "c": 20 + idx,
                }))
        pooler = PoolingBatcher(train_batch_size=3, pool_num_batches=2)

        def by_a(row):
            return row.numberized["a"]

        produced = list(pooler.batchify(rows, sort_key=by_a))

        self.assertEqual(len(produced), 4)
        # Every input row must land in some batch.
        seen = set()
        for _, numberized in produced:
            seen.update(numberized["a"])
        self.assertSetEqual(seen, set(range(10)))

        first_pool, second_pool = produced[:2], produced[2:]
        for raw, numberized in first_pool:
            expected_raw = [{"text": "something"}] * len(raw)
            self.assertEqual(expected_raw, list(raw))
            a_col = numberized["a"]
            # Within the pool the sort key runs from high to low.
            self.assertGreater(a_col[0], a_col[-1])
            for value in a_col:
                self.assertLess(value, 6)
        for _, numberized in second_pool:
            for value in numberized["a"]:
                self.assertGreaterEqual(value, 6)
Ejemplo n.º 2
0
 def _yield_and_reset(self, row):
     """Emit the buffered remainder as one packed row, then empty the buffer.

     Args:
         row: raw row data attached to the emitted ``RowData``.

     Returns:
         ``RowData(row, formatted)``, where ``formatted`` comes from
         ``self._format_output_row`` applied to list copies of the buffered
         ``tokens`` and ``segment_labels`` plus the number of tokens.
     """
     buffered = self.remainder
     tokens = list(buffered["tokens"])
     segments = list(buffered["segment_labels"])
     # Start a new, empty buffer for the next packed row.
     self.remainder: Dict[str, List[int]] = dict(tokens=[], segment_labels=[])
     formatted = self._format_output_row(tokens, segments, len(tokens))
     return RowData(row, formatted)
Ejemplo n.º 3
0
 def _yield_and_reset(self):
     """Emit the buffered remainder as one packed row, then empty the buffer.

     Returns:
         ``RowData({}, formatted)``, where ``formatted`` comes from
         ``self._format_output_row`` applied to list copies of the buffered
         ``tokens`` and ``segment_labels`` plus the number of tokens. The raw
         row is always an empty dict.
     """
     buffered = self.remainder
     tokens = list(buffered["tokens"])
     segments = list(buffered["segment_labels"])
     # Start a new, empty buffer for the next packed row.
     self.remainder: Dict[str, List[int]] = dict(tokens=[], segment_labels=[])
     formatted = self._format_output_row(tokens, segments, len(tokens))
     # Packed LM data doesn't respect data cardinality, so no raw row is kept.
     return RowData({}, formatted)
Ejemplo n.º 4
0
 def test_batcher(self):
     """Batcher splits rows into fixed-size batches in input order.

     Ten rows with ``a = i`` and ``b = 10 + i`` are batched three at a time.
     That yields four batches, the last holding the single leftover row.
     """
     rows = [
         RowData({"text": "something"}, dict(a=idx, b=10 + idx, c=20 + idx))
         for idx in range(10)
     ]
     produced = list(Batcher(train_batch_size=3).batchify(rows))

     self.assertEqual(len(produced), 4)
     first, second, last = produced[0], produced[1], produced[3]
     self.assertEqual(len(first.raw_data), 3)
     self.assertEqual("something", second.raw_data[0]["text"])
     # The second batch covers rows 3..5.
     self.assertEqual(second.numberized["a"], [3, 4, 5])
     # Only the tenth row (b = 19) is left for the final batch.
     self.assertEqual(last.numberized["b"], [19])