def setup(self, stage: str):
    if stage == 'fit' and self.train_dataset is None:
        # Lazily build the train/eval datasets on the first fit setup call.
        tindex = os.path.join(self.hparams.data_dir, 'train_index')
        eindex = os.path.join(self.hparams.data_dir, 'eval_index')
        self.train_dataset = LexicalTrainDataset(
            tindex,
            self.tokenizer,
            max_examples=self.hparams.max_train_examples)
        self.eval_dataset = LexicalTrainDataset(
            eindex,
            self.tokenizer,
            max_examples=self.hparams.max_eval_examples)
    elif stage != 'fit':
        # Non-fit stages only need the test dataset; guarding on the stage
        # (instead of a bare else) avoids loading test_index when
        # setup('fit') is called again after the datasets already exist.
        tindex = os.path.join(self.hparams.data_dir, 'test_index')
        self.test_dataset = LexicalTrainDataset(tindex, self.tokenizer)
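# A minimal sketch of how the datasets built in setup() would typically be
# exposed by the same LightningDataModule. The batch_size hparam and the use
# of collate_batch as the collate_fn are assumptions for illustration, not
# code from the original module.
from torch.utils.data import DataLoader

def train_dataloader(self):
    # collate_batch (exercised in test_collate_batch below) turns a list of
    # per-example dicts into a single dict of batched tensors.
    return DataLoader(
        self.train_dataset,
        batch_size=self.hparams.batch_size,
        shuffle=True,
        collate_fn=self.train_dataset.collate_batch)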
def test_collate_batch(self, exists_mock):
    with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
        dataset = LexicalTrainDataset('/some/index', BERT_TOKENIZER)
        examples = [dataset[i] for i in range(len(dataset))]

        batch = dataset.collate_batch(examples)

        self.assertIn('input_ids', batch)
        self.assertIn('token_type_ids', batch)
        self.assertIn('attention_mask', batch)
        self.assertIn('next_sentence_label', batch)
        self.assertIn('mlm_labels', batch)
        self.assertTupleEqual(batch['input_ids'].shape, (2, 9))
        self.assertTupleEqual(batch['token_type_ids'].shape, (2, 9))
        self.assertTupleEqual(batch['attention_mask'].shape, (2, 9))
        self.assertTupleEqual(batch['next_sentence_label'].shape, (2,))
        self.assertTupleEqual(batch['mlm_labels'].shape, (2, 9))
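# The test above pins down collate_batch's contract rather than its
# implementation: five keys, the sequence tensors padded to a common length,
# and one next-sentence label per example. A hypothetical implementation
# consistent with those shapes (pad_sequence, zero padding, and -100 as the
# MLM ignore index are all assumptions) could look like this:
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_batch(examples):
    # Pad every per-example 1-D tensor to the longest sequence in the batch,
    # and concatenate the length-1 next-sentence labels into a (batch,) tensor.
    return {
        'input_ids': pad_sequence(
            [e['input_ids'] for e in examples], batch_first=True),
        'token_type_ids': pad_sequence(
            [e['token_type_ids'] for e in examples], batch_first=True),
        'attention_mask': pad_sequence(
            [e['attention_mask'] for e in examples], batch_first=True),
        'next_sentence_label': torch.cat(
            [e['next_sentence_label'] for e in examples]),
        'mlm_labels': pad_sequence(
            [e['mlm_labels'] for e in examples],
            batch_first=True, padding_value=-100),
    }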
def test_dataset_limit(self, exists_mock):
    with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
        dataset = LexicalTrainDataset(
            '/some/index', BERT_TOKENIZER, max_examples=1)

        # Only max_examples examples are actually loaded...
        self.assertEqual(len(dataset), 1)
        # ...but the total example count reported by the index is preserved.
        self.assertEqual(dataset.total_examples, 27)
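# Where the 27 above comes from: the mocked index used throughout these tests
# (see test_index_creation below) holds the inclusive spans 0-9, 10-21 and
# 22-26, so the index advertises 10 + 12 + 5 = 27 examples in total, of which
# max_examples=1 materializes only the first.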
def test_index_creation(self, mock_load, exists_mock):
    expected_index = [
        IndexEntry(0, 9, 'location_a'),
        IndexEntry(10, 21, 'location_b'),
        IndexEntry(22, 26, 'location_c'),
    ]
    with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
        dataset = LexicalTrainDataset('/some/index', BERT_TOKENIZER)

        self.assertEqual(len(dataset.index), 3)
        self.assertListEqual(dataset.index, expected_index)
        # Example loading is patched out (mock_load), so nothing is
        # materialized beyond the index itself.
        self.assertListEqual(dataset.examples, [])
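# IndexEntry itself is not shown in this excerpt. A shape that matches its use
# above (an inclusive example-id range plus a shard location) would be the
# following; the field names are assumptions for illustration:
from typing import NamedTuple

class IndexEntry(NamedTuple):
    start: int     # id of the first example covered by this entry (inclusive)
    end: int       # id of the last example covered by this entry (inclusive)
    location: str  # where this entry's examples are stored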
def test_dataset_getitem(self, exists_mock):
    expected_input_ids = torch.arange(1, 10, dtype=torch.long)
    expected_token_type_ids = torch.tensor([0] * 9, dtype=torch.long)
    expected_attention_mask = torch.ones(9)
    expected_is_random_next = torch.tensor([0], dtype=torch.long)
    with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
        dataset = LexicalTrainDataset('/some/index', BERT_TOKENIZER)

        example = dataset[0]

        self.assertTrue(
            torch.all(torch.eq(example['input_ids'], expected_input_ids)))
        self.assertTrue(
            torch.all(
                torch.eq(example['token_type_ids'], expected_token_type_ids)))
        self.assertTrue(
            torch.all(
                torch.eq(example['attention_mask'], expected_attention_mask)))
        self.assertTrue(
            torch.all(
                torch.eq(example['next_sentence_label'],
                         expected_is_random_next)))
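# Design note on the assertions above: torch.all(torch.eq(a, b)) is used
# rather than torch.equal(a, b) because torch.eq promotes across dtypes,
# while torch.equal also requires matching dtypes; expected_attention_mask is
# created as float ones, whereas the dataset may return an integer mask.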
def test_dataset_len(self, exists_mock):
    with patch(DATASET_OPEN_FUNC, get_mock_open_index()):
        dataset = LexicalTrainDataset('/some/index', BERT_TOKENIZER)

        self.assertEqual(len(dataset), 2)
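# get_mock_open_index is also not part of this excerpt. A common shape for
# such a helper is unittest.mock's mock_open preloaded with fixture data; the
# payload below is purely illustrative, since the real index format consumed
# by LexicalTrainDataset is not visible here:
from unittest.mock import mock_open

def get_mock_open_index():
    # Hypothetical tab-separated index lines matching the IndexEntry spans
    # asserted in test_index_creation.
    fake_index = ('0\t9\tlocation_a\n'
                  '10\t21\tlocation_b\n'
                  '22\t26\tlocation_c\n')
    return mock_open(read_data=fake_index)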