Example no. 1
    def setup(self, stage: str):
        # Build the train/eval datasets for the 'fit' stage; any other
        # stage (e.g. 'test') gets the test dataset. The inner None check
        # keeps repeated setup('fit') calls from reloading the data without
        # falling through to the test branch.
        if stage == 'fit':
            if self.train_dataset is None:
                tindex = os.path.join(self.hparams.data_dir, 'train_index')
                eindex = os.path.join(self.hparams.data_dir, 'eval_index')

                self.train_dataset = LexicalTrainDataset(
                    tindex,
                    self.tokenizer,
                    max_examples=self.hparams.max_train_examples)

                self.eval_dataset = LexicalTrainDataset(
                    eindex,
                    self.tokenizer,
                    max_examples=self.hparams.max_eval_examples)
        else:
            tindex = os.path.join(self.hparams.data_dir, 'test_index')
            self.test_dataset = LexicalTrainDataset(tindex, self.tokenizer)
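
    # Hedged sketch (not part of this excerpt): how the datasets built in
    # setup() would typically be exposed to the trainer, wiring the
    # dataset's own collate_batch in as collate_fn. The batch_size
    # hyperparameter is an assumption.
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            collate_fn=self.train_dataset.collate_batch)
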
    def test_collate_batch(self, exists_mock):
        with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
            dataset = LexicalTrainDataset('/some/index', BERT_TOKENIZER)
            examples = [dataset[i] for i in range(len(dataset))]

            batch = dataset.collate_batch(examples)

            self.assertIn('input_ids', batch)
            self.assertIn('token_type_ids', batch)
            self.assertIn('attention_mask', batch)
            self.assertIn('next_sentence_label', batch)
            self.assertIn('mlm_labels', batch)

            self.assertTupleEqual(batch['input_ids'].shape, (2, 9))
            self.assertTupleEqual(batch['token_type_ids'].shape, (2, 9))
            self.assertTupleEqual(batch['attention_mask'].shape, (2, 9))
            self.assertTupleEqual(batch['next_sentence_label'].shape, (2, ))
            self.assertTupleEqual(batch['mlm_labels'].shape, (2, 9))
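
    # Hedged sketch (an assumption, not the project's code): a collate_batch
    # consistent with the assertions above. Each tensor field is padded to
    # the longest example in the batch (padding value 0 here; MLM labels are
    # often padded with -100 instead), and the per-example
    # next_sentence_label tensors are concatenated into shape (batch,).
    @staticmethod
    def _collate_batch_sketch(examples):
        from torch.nn.utils.rnn import pad_sequence
        batch = {
            key: pad_sequence([e[key] for e in examples], batch_first=True)
            for key in ('input_ids', 'token_type_ids', 'attention_mask',
                        'mlm_labels')
        }
        batch['next_sentence_label'] = torch.cat(
            [e['next_sentence_label'] for e in examples])
        return batch
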
    def test_dataset_limit(self, exists_mock):
        with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
            dataset = LexicalTrainDataset('/some/index',
                                          BERT_TOKENIZER,
                                          max_examples=1)

            # Check loaded examples
            self.assertEqual(len(dataset), 1)

            # Check total examples as per index
            self.assertEqual(dataset.total_examples, 27)
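
    # Note: 27 matches the mocked index used across these tests (see
    # test_index_creation below): its three inclusive ranges cover
    # (9 - 0 + 1) + (21 - 10 + 1) + (26 - 22 + 1) = 10 + 12 + 5 = 27
    # examples, of which max_examples=1 loads just the first.
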
    def test_index_creation(self, mock_load, exists_mock):
        expected_index = [
            IndexEntry(0, 9, 'location_a'),
            IndexEntry(10, 21, 'location_b'),
            IndexEntry(22, 26, 'location_c'),
        ]

        with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
            dataset = LexicalTrainDataset('/some/index', BERT_TOKENIZER)

            self.assertEqual(len(dataset.index), 3)
            self.assertListEqual(dataset.index, expected_index)
            self.assertListEqual(dataset.examples, [])
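
    # The expectations above imply that IndexEntry is an inclusive
    # (start, end) example range plus the shard location that stores it;
    # a hedged sketch of the record is given in the scaffolding at the end
    # of this file.
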
    def test_dataset_getitem(self, exists_mock):
        expected_input_ids = torch.arange(1, 10, dtype=torch.long)
        expected_token_type_ids = torch.tensor([0] * 9, dtype=torch.long)
        expected_attention_mask = torch.ones(9)
        expected_is_random_next = torch.tensor([0], dtype=torch.long)

        with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
            dataset = LexicalTrainDataset('/some/index', BERT_TOKENIZER)
            example = dataset[0]

            self.assertTrue(
                torch.all(torch.eq(example['input_ids'], expected_input_ids)))
            self.assertTrue(
                torch.all(
                    torch.eq(example['token_type_ids'],
                             expected_token_type_ids)))
            self.assertTrue(
                torch.all(
                    torch.eq(example['attention_mask'],
                             expected_attention_mask)))
            self.assertTrue(
                torch.all(
                    torch.eq(example['next_sentence_label'],
                             expected_is_random_next)))

    def test_dataset_len(self, exists_mock):
        with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
            dataset = LexicalTrainDataset('/some/index', BERT_TOKENIZER)

            self.assertEqual(len(dataset), 2)
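

# ---------------------------------------------------------------------------
# Hedged sketch of the module-level scaffolding the tests above rely on. None
# of it appears in this excerpt, so every name, path, and data format below
# is an assumption about its shape, not the project's actual code.
import os
from collections import namedtuple
from unittest.mock import mock_open, patch

import torch
from transformers import BertTokenizer

# Record implied by test_index_creation: an inclusive (start, end) example
# range plus the shard location holding it (field names are assumptions).
IndexEntry = namedtuple('IndexEntry', ['start', 'end', 'location'])

# Dotted path of the open() call inside the dataset module (assumed path),
# patched so the tests never touch a real index file.
DATASET_OPEN_FUNC = 'lexical_dataset.open'

# Tokenizer shared across tests; the checkpoint name is an assumption.
BERT_TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased')


def get_mock_open_index():
    # Returns a mock_open whose read data resembles the three-entry index
    # asserted in test_index_creation (the tab-separated layout is an
    # assumption about the on-disk format).
    data = ('0\t9\tlocation_a\n'
            '10\t21\tlocation_b\n'
            '22\t26\tlocation_c\n')
    return mock_open(read_data=data)

# The exists_mock / mock_load parameters of the test methods would come from
# stacked @patch decorators (e.g. @patch('os.path.exists',
# return_value=True)) applied at class or method level, not shown here.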