def test_add_weights(self):
        """Check that ``add_weights`` gives the rarer class the larger weight.

        The training labels are ``[1, 1, 1, 0]``: class 0 appears once and
        class 1 three times.  For both the default weighting method and the
        ``"log"`` method, the weight stored for class 0 is expected to be
        greater than the weight stored for class 1.
        """
        dataset_opt = MockDatasetConfig()
        dataset_opt.dataroot = os.path.join(DIR, "temp_dataset")

        mock_base_dataset = MockBaseDataset(dataset_opt)
        # Class 0 is the minority class (1 of 4 labels).
        mock_base_dataset.train_dataset.data = Data(y=torch.tensor([1, 1, 1, 0]))

        # Default class-weighting method.
        mock_base_dataset.add_weights()
        weights = mock_base_dataset.weight_classes
        self.assertGreater(weights[0], weights[1], msg="default class_weight_method")

        # Logarithmic class-weighting method.
        mock_base_dataset.add_weights(class_weight_method="log")
        weights = mock_base_dataset.weight_classes
        self.assertGreater(weights[0], weights[1], msg='class_weight_method="log"')

    def test_normal(self):
        """A single (non-list) test dataset produces exactly one test dataloader."""
        config = MockDatasetConfig()
        config.dataroot = os.path.join(DIR, "temp_dataset")

        base = MockBaseDataset(config)
        base.test_dataset = MockDataset()

        model_cfg = MockModelConfig()
        model_cfg.conv_type = "dense"
        # NOTE(review): the positional arguments (2, True, 0, False) are passed
        # exactly as before; their meaning is defined by create_dataloaders.
        base.create_dataloaders(MockModel(model_cfg), 2, True, 0, False)

        self.assertEqual(len(base.test_dataloaders), 1)

    def test_get_by_name(self):
        """Check ``get_dataset`` lookups by split name and by a custom dataset name.

        Also checks two error cases:

        - After the test split is replaced, looking up a stale name raises
          ``ValueError``.
        - Assigning a test-dataset list that contains the same named dataset
          twice raises ``ValueError``.
        """
        dataset_opt = MockDatasetConfig()
        setattr(dataset_opt, "dataroot", os.path.join(DIR, "temp_dataset"))

        mock_base_dataset = MockBaseDataset(dataset_opt)
        # Two unnamed test datasets.  The lookups below expect them to be
        # reachable as "test_0" and "test_1" — presumably names assigned by the
        # test_dataset setter; confirm against BaseDataset.
        mock_base_dataset.test_dataset = [MockDataset(), MockDataset()]
        mock_base_dataset.train_dataset = MockDataset()
        mock_base_dataset.val_dataset = MockDataset()

        for name in ["train", "val", "test_0", "test_1"]:
            self.assertEqual(mock_base_dataset.get_dataset(name).name, name)

        # Replace the test split with a single dataset that carries its own name.
        test_with_name = MockDataset()
        setattr(test_with_name, "name", "testos")
        mock_base_dataset.test_dataset = test_with_name
        # After the replacement, the old "test_1" lookup is expected to fail.
        with self.assertRaises(ValueError):
            mock_base_dataset.get_dataset("test_1")
        # The custom name must resolve without raising.
        # NOTE(review): the returned dataset is not checked; consider asserting
        # that its ``.name`` equals "testos".
        mock_base_dataset.get_dataset("testos")

        # Presumably the test_dataset setter rejects duplicated dataset names.
        with self.assertRaises(ValueError):
            mock_base_dataset.test_dataset = [test_with_name, test_with_name]