Example #1
0
 def test_batcher(self):
     """Batching 10 rows at size 3 yields 3 full batches plus a remainder of 1."""
     rows = [{"a": n, "b": 10 + n, "c": 20 + n} for n in range(10)]
     result = list(Batcher(train_batch_size=3).batchify(rows))
     self.assertEqual(len(result), 4)
     # Columns are batched by key; the second batch covers rows 3..5.
     self.assertEqual(result[1]["a"], [3, 4, 5])
     # Final (partial) batch holds only the last row.
     self.assertEqual(result[3]["b"], [19])
Example #2
0
 def test_batcher(self):
     """Batchify keeps raw rows and numberized columns aligned per batch."""
     rows = []
     for n in range(10):
         rows.append(
             RowData({"text": "something"}, {"a": n, "b": 10 + n, "c": 20 + n})
         )
     result = list(Batcher(train_batch_size=3).batchify(rows))
     # 10 rows at batch size 3 -> 3 full batches plus one of size 1.
     self.assertEqual(len(result), 4)
     self.assertEqual(len(result[0].raw_data), 3)
     # The raw side of each row is carried through untouched.
     self.assertEqual("something", result[1].raw_data[0]["text"])
     self.assertEqual(result[1].numberized["a"], [3, 4, 5])
     self.assertEqual(result[3].numberized["b"], [19])
Example #3
0
    def test_sort(self):
        """Sorting by "tokens" orders each batch by descending sequence length
        and keeps labels and raw rows aligned with the reordered tokens."""
        data = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=5),
            sort_key="tokens",
        )

        def check_descending(batch):
            # Sequence lengths within a batch must be non-increasing.
            _, lengths, _ = batch["tokens"]
            lengths = lengths.tolist()
            for prev, cur in zip(lengths, lengths[1:]):
                self.assertTrue(prev >= cur)

        batch_iter = iter(list(data.batches(Stage.TRAIN)))
        first_raw, first_numberized = next(batch_iter)
        check_descending(first_numberized)
        # Labels must follow the same order as the sorted tokens.
        self.assertEqual(
            self.tensorizers["labels"].vocab[first_numberized["labels"][1]],
            "alarm/set_alarm",
        )
        self.assertEqual(first_raw[1][RawExampleFieldName.ROW_INDEX], 1)
        second_raw, second_numberized = next(batch_iter)
        check_descending(second_numberized)
        self.assertEqual(
            self.tensorizers["labels"].vocab[second_numberized["labels"][1]],
            "alarm/time_left_on_alarm",
        )
        # Raw rows are reordered along with the sort, so row indices permute.
        self.assertEqual(second_raw[0][RawExampleFieldName.ROW_INDEX], 6)
        self.assertEqual(second_raw[1][RawExampleFieldName.ROW_INDEX], 5)
Example #4
0
 def _get_pytext_config(
     self,
     test_file_name: TestFileName,
     task_class: Type[NewTask],
     model_class: Type[Model],
 ) -> PyTextConfig:
     """Build a minimal PyTextConfig that points every split (train/eval/test)
     at the same test file and trains for a single epoch."""
     metadata = get_test_file_metadata(test_file_name)
     source_config = TSVDataSource.Config(
         train_filename=metadata.filename,
         eval_filename=metadata.filename,
         test_filename=metadata.filename,
         field_names=metadata.field_names,
     )
     # Use Batcher to avoid shuffling.
     data_config = Data.Config(source=source_config, batcher=Batcher.Config())
     # Mirror the model's own inputs type so the dense tensorizer slots in.
     model_inputs = type(model_class.Config.inputs)(
         dense=FloatListTensorizer.Config(
             column=metadata.dense_col_name,
             dim=metadata.dense_feat_dim,
         )
     )
     return PyTextConfig(
         task=task_class.Config(
             data=data_config,
             trainer=TaskTrainer.Config(epochs=1),
             model=model_class.Config(inputs=model_inputs),
         ),
         use_tensorboard=False,
         use_cuda_if_available=False,
         version=LATEST_VERSION,
     )
Example #5
0
 def test_create_batches_different_tensorizers(self):
     """A batch only exposes keys for the tensorizers it was built with."""
     tensorizers = {"tokens": WordTensorizer(column="text")}
     data = Data(self.data_source, tensorizers, Batcher(train_batch_size=16))
     all_batches = list(data.batches(Stage.TRAIN))
     # All 10 rows fit within a single batch of 16.
     self.assertEqual(1, len(all_batches))
     batch = all_batches[0]
     self.assertEqual({"tokens"}, set(batch))
     tokens, seq_lens = batch["tokens"]
     self.assertEqual((10,), seq_lens.size())
     self.assertEqual(10, len(tokens))
    def test_reset_incremental_states(self):
        """
        This test might seem trivial. However, interacting with the scripted
        sequence generator crosses the Torchscript boundary, which can lead
        to weird behavior. If the incremental states don't get properly
        reset, the model will produce garbage _after_ the first call, which
        is a pain to debug when you only catch it after training.
        """
        tensorizers = get_tensorizers()

        # Avoid numeric issues with quantization by setting a known seed.
        torch.manual_seed(42)

        # Build a small seq2seq model from config with the shared tensorizers.
        model = Seq2SeqModel.from_config(
            Seq2SeqModel.Config(
                source_embedding=WordEmbedding.Config(embed_dim=512),
                target_embedding=WordEmbedding.Config(embed_dim=512),
            ),
            tensorizers,
        )

        # Get sample inputs using a data source.
        schema = {
            "source_sequence": str,
            "dict_feat": Gazetteer,
            "target_sequence": str,
        }
        data = Data.from_config(
            Data.Config(source=TSVDataSource.Config(
                train_filename=TEST_FILE_NAME,
                field_names=[
                    "source_sequence", "dict_feat", "target_sequence"
                ],
            )),
            schema,
            tensorizers,
        )
        # Batch size 1 for every stage: we only need one example.
        data.batcher = Batcher(1, 1, 1)
        raw_batch, batch = next(
            iter(data.batches(Stage.TRAIN, load_early=True)))
        inputs = model.arrange_model_inputs(batch)

        model.eval()
        outputs = model(*inputs)
        # First generation pass; this populates the decoder's incremental states.
        pred, scores = model.get_pred(outputs, {"stage": Stage.TEST})

        # Verify that the incremental states reset correctly.
        decoder = model.sequence_generator.beam_search.decoder_ens
        decoder.reset_incremental_states()
        self.assertDictEqual(decoder.incremental_states, {"0": {}})

        # Verify that the model returns the same predictions.
        new_pred, new_scores = model.get_pred(outputs, {"stage": Stage.TEST})
        self.assertEqual(new_scores, scores)
Example #7
0
 def test_create_batches(self):
     """One oversized batch collects all rows, keyed by tensorizer name."""
     data = Data(self.data_source, self.tensorizers, Batcher(train_batch_size=16))
     all_batches = list(data.batches(Stage.TRAIN))
     # All 10 rows fit within a single batch of 16.
     self.assertEqual(1, len(all_batches))
     batch = all_batches[0]
     self.assertEqual(set(self.tensorizers), set(batch))
     tokens, seq_lens = batch["tokens"]
     self.assertEqual((10,), seq_lens.size())
     self.assertEqual((10,), batch["labels"].size())
     self.assertEqual({"tokens", "labels"}, set(batch))
     self.assertEqual(10, len(tokens))
Example #8
0
    def test_create_batches_with_cache(self):
        """in_memory=True caches numberized rows; concurrent iteration raises."""
        data = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=1),
            in_memory=True,
        )
        # Drain one epoch so the cache is populated.
        list(data.batches(Stage.TRAIN))
        self.assertEqual(10, len(data.numberized_cache[Stage.TRAIN]))

        cached_data = Data(
            self.data_source,
            self.tensorizers,
            Batcher(train_batch_size=1),
            in_memory=True,
        )
        with self.assertRaises(Exception):
            # Concurrent iteration not supported
            outer = cached_data.batches(Stage.TRAIN)
            inner = cached_data.batches(Stage.TRAIN)
            for _ in outer:
                for _ in inner:
                    continue
Example #9
0
 def test_create_batches(self):
     """Each batch pairs raw rows with their numberized tensors."""
     data = Data(
         self.data_source, self.tensorizers, Batcher(train_batch_size=16)
     )
     all_batches = list(data.batches(Stage.TRAIN))
     # All 10 rows fit within a single batch of 16.
     self.assertEqual(1, len(all_batches))
     raw_batch, batch = all_batches[0]
     self.assertEqual(set(self.tensorizers), set(batch))
     tokens, seq_lens, _ = batch["tokens"]
     # Raw rows travel alongside the numberized batch, plus a row-index field.
     self.assertEqual(10, len(raw_batch))
     self.assertEqual(
         {"text", "label", RawExampleFieldName.ROW_INDEX}, set(raw_batch[0])
     )
     self.assertEqual((10, ), seq_lens.size())
     self.assertEqual((10, ), batch["labels"].size())
     self.assertEqual({"tokens", "labels"}, set(batch))
     self.assertEqual(10, len(tokens))
Example #10
0
    def test_force_predictions_on_eval(self):
        """EVAL stage skips sequence generation unless force_eval_predictions
        is set, and setting it without proper setup is rejected."""
        tensorizers = get_tensorizers()

        # Build a small seq2seq model from config with the shared tensorizers.
        model = Seq2SeqModel.from_config(
            Seq2SeqModel.Config(
                source_embedding=WordEmbedding.Config(embed_dim=512),
                target_embedding=WordEmbedding.Config(embed_dim=512),
            ),
            tensorizers,
        )

        # Get sample inputs using a data source.
        schema = {
            "source_sequence": str,
            "dict_feat": Gazetteer,
            "target_sequence": str,
        }
        data = Data.from_config(
            Data.Config(source=TSVDataSource.Config(
                train_filename=TEST_FILE_NAME,
                field_names=[
                    "source_sequence", "dict_feat", "target_sequence"
                ],
            )),
            schema,
            tensorizers,
        )
        # Batch size 1 for every stage: we only need one example.
        data.batcher = Batcher(1, 1, 1)
        raw_batch, batch = next(
            iter(data.batches(Stage.TRAIN, load_early=True)))
        inputs = model.arrange_model_inputs(batch)

        # Verify that model does not run sequence generation on prediction.
        outputs = model(*inputs)
        pred = model.get_pred(outputs, {"stage": Stage.EVAL})
        self.assertEqual(pred, (None, None))

        # Verify that attempting to set force_eval_predictions is correctly
        # accounted for.
        model.force_eval_predictions = True
        with self.assertRaises(AssertionError):
            _ = model.get_pred(outputs, {"stage": Stage.EVAL})
Example #11
0
 def test_sort(self):
     """Sorting by "tokens" orders the batch by descending sequence length
     while keeping labels aligned with the reordered tokens."""
     data = Data(
         self.data_source,
         self.tensorizers,
         Batcher(train_batch_size=16),
         sort_key="tokens",
     )
     all_batches = list(data.batches(Stage.TRAIN))
     batch = all_batches[0]
     _, lengths, _ = batch["tokens"]
     lengths = lengths.tolist()
     # Within the batch, sequence lengths must be non-increasing.
     for prev, cur in zip(lengths, lengths[1:]):
         self.assertTrue(prev >= cur)
     # make sure labels are also in the same order of sorted tokens
     label_vocab = self.tensorizers["labels"].vocab
     self.assertEqual(label_vocab[batch["labels"][1]], "reminder/set_reminder")
     self.assertEqual(label_vocab[batch["labels"][8]], "alarm/snooze_alarm")