def test_add_weights(self):
    """Check that ``add_weights`` gives the rarer class the larger weight.

    The training labels are ``[1, 1, 1, 0]``: class 0 occurs once and
    class 1 three times. Both the default call and the call with
    ``class_weight_method="log"`` must rank the weight of class 0 above
    the weight of class 1.
    """
    dataset_opt = MockDatasetConfig()
    dataset_opt.dataroot = os.path.join(DIR, "temp_dataset")

    mock_base_dataset = MockBaseDataset(dataset_opt)
    mock_base_dataset.train_dataset.data = Data(y=torch.tensor([1, 1, 1, 0]))

    # Default weighting method.
    mock_base_dataset.add_weights()
    weights = mock_base_dataset.weight_classes
    self.assertGreater(weights[0], weights[1])

    # Same ordering is expected when the "log" method is requested.
    mock_base_dataset.add_weights(class_weight_method="log")
    weights = mock_base_dataset.weight_classes
    self.assertGreater(weights[0], weights[1])
def test_normal(self):
    """Check that a single test dataset produces exactly one test dataloader.

    A lone ``MockDataset`` is assigned as the test dataset, dataloaders are
    created for a model whose ``conv_type`` is ``"dense"``, and the resulting
    ``test_dataloaders`` collection must contain one entry.
    """
    config = MockDatasetConfig()
    config.dataroot = os.path.join(DIR, "temp_dataset")

    base_dataset = MockBaseDataset(config)
    base_dataset.test_dataset = MockDataset()

    model_opt = MockModelConfig()
    model_opt.conv_type = "dense"
    dense_model = MockModel(model_opt)

    base_dataset.create_dataloaders(dense_model, 2, True, 0, False)

    test_loaders = base_dataset.test_dataloaders
    self.assertEqual(len(test_loaders), 1)
def test_get_by_name(self):
    """Check dataset lookup by name through ``get_dataset``.

    Covers three situations:

    * with train, val and two test datasets assigned, each of ``"train"``,
      ``"val"``, ``"test_0"`` and ``"test_1"`` resolves to a dataset whose
      ``name`` equals the requested name;
    * after replacing the test datasets with a single one named
      ``"testos"``, ``"test_1"`` raises ``ValueError`` while ``"testos"``
      can be looked up;
    * assigning two test datasets that share a name raises ``ValueError``.
    """
    config = MockDatasetConfig()
    config.dataroot = os.path.join(DIR, "temp_dataset")
    base_dataset = MockBaseDataset(config)

    # Assignment order is kept as-is: the dataset setters may have side effects.
    base_dataset.test_dataset = [MockDataset(), MockDataset()]
    base_dataset.train_dataset = MockDataset()
    base_dataset.val_dataset = MockDataset()

    expected_names = ["train", "val", "test_0", "test_1"]
    for expected in expected_names:
        found = base_dataset.get_dataset(expected)
        self.assertEqual(found.name, expected)

    # Swap in a single test dataset that carries an explicit name.
    named_dataset = MockDataset()
    named_dataset.name = "testos"
    base_dataset.test_dataset = named_dataset

    with self.assertRaises(ValueError):
        base_dataset.get_dataset("test_1")
    # Must not raise: the explicitly named dataset is reachable by its name.
    base_dataset.get_dataset("testos")

    # Two test datasets with the same name are rejected on assignment.
    with self.assertRaises(ValueError):
        base_dataset.test_dataset = [named_dataset, named_dataset]