Code Example #1
    def test_dump_load(self):
        for strategyfn in self.looping_strategies:
            strategy = strategyfn()
            with self.subTest(msg=f"Checking dump/load for strategy {type(strategy).__name__}"):
                with tempfile.TemporaryDirectory(prefix="LoopingStrategy_DumpPath") as dump_path:
                    dump_path = Path(dump_path) / "samples.p"
                    dataloader = td.LoopingDataGenerator(
                        self.test_set.paths,
                        dg.get_filelist_within_folder,
                        _dummy_dataloader_fn,
                        looping_strategy=strategy)
                    # Consume the first 16 batches of epoch one.
                    first_epoch = {s for _, s in zip(
                        range(16), _get_samples_in_epoch(dataloader))}
                    self.assertEqual(len(dataloader), 16 * self.batch_size)
                    try:
                        strategy.dump_content(dump_path)  # TODO: The DL should call dump and load
                    except NotImplementedError:
                        # Strategies without persistence support are skipped.
                        continue

                    strategy = strategyfn()
                    strategy.load_content(dump_path)
                    dataloader = td.LoopingDataGenerator(
                        self.test_set.paths,
                        dg.get_filelist_within_folder,
                        _dummy_dataloader_fn,
                        looping_strategy=strategy)
                    second_epoch = {s for _, s in zip(
                        range(16), _get_samples_in_epoch(dataloader))}
                    self.assertSetEqual(first_epoch, second_epoch)
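Example #1 above, and Examples #4 and #5 below, reference two module-level test helpers that the listing never shows. A minimal sketch of what they might look like, assuming the generator yields (data, label, aux) triples as in Example #7 and that these tests run at a batch size of 1; both bodies are reconstructions, not the project's actual code:

import torch


def _dummy_dataloader_fn(filename):
    # Map each file to one (data, label) pair with a per-file value so that
    # samples stay distinguishable when the tests collect them into sets.
    value = float(hash(str(filename)) % 10_000)
    return [(torch.tensor([value]), torch.tensor([0.0]))]


def _get_samples_in_epoch(dataloader):
    # Yield one hashable value per batch; with batch size 1 (assumed here)
    # that is one value per sample.
    for data, _labels, _ in dataloader:
        yield tuple(data.flatten().tolist())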
Code Example #2
    def __create_datagenerator(self, test_mode=False):
        try:
            generator = td.LoopingDataGenerator(
                self.data_source_paths,
                self.data_gather_function,
                self.data_processing_function,
                batch_size=self.batch_size,
                num_validation_samples=self.num_validation_samples,
                num_test_samples=self.num_test_samples,
                split_load_path=self.load_datasets_path,
                split_save_path=self.save_path,
                split_data_root=self.data_root,
                num_workers=self.num_workers,
                cache_path=self.cache_path,
                cache_mode=self.cache_mode,
                looping_strategy=self.looping_strategy,
                save_torch_dataset_path=self.save_torch_dataset_path,
                load_torch_dataset_path=self.load_torch_dataset_path,
                dont_care_num_samples=self.dont_care_num_samples,
                test_mode=test_mode,
                sampler=self.sampler,
                load_test_set_in_training_mode=self.load_test_set_in_training_mode,
            )
        except Exception:
            logger = logging.getLogger(__name__)
            logger.exception("Fatal Error:")
            exit()
        return generator
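Only the test_mode flag varies per call, so the surrounding class presumably wraps this private factory in public accessors. A minimal sketch of such call sites; the method names are assumptions, not part of the snippet:

    def get_training_generator(self):
        # Hypothetical wrapper: generator that serves training batches.
        return self.__create_datagenerator(test_mode=False)

    def get_test_generator(self):
        # Hypothetical wrapper: generator restricted to the test split.
        return self.__create_datagenerator(test_mode=True)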
Code Example #3
def load_test_data():
    # Gather the full test split into flat numpy arrays of data and labels.
    test_datagenerator = td.LoopingDataGenerator(
        r.get_data_paths_base_0(),
        get_filelist_within_folder_blacklisted,
        dl.get_sensor_bool_dryspot,
        batch_size=512,
        num_validation_samples=131072,
        num_test_samples=1048576,
        split_load_path=r.datasets_dryspots, 
        split_save_path=r.save_path,
        num_workers=75,
        cache_path=r.cache_path,
        cache_mode=td.CachingMode.Both,
        dont_care_num_samples=False,
        test_mode=True)

    test_data = []
    test_labels = []
    test_set = test_datagenerator.get_test_samples()

    for data, labels, _ in test_set:
        test_data.extend(data)
        test_labels.extend(labels)

    test_data = np.array(test_data)
    test_labels = np.array(test_labels)
    test_labels = np.ravel(test_labels)
    print("Loaded Test data.")
    return test_data, test_labels
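The arrays returned by load_test_data are ready for a classical estimator. A hypothetical evaluation pass, where model is assumed to be any already-fitted classifier with a scikit-learn-style predict():

test_data, test_labels = load_test_data()
predictions = model.predict(test_data)  # `model`: any fitted classifier (assumed)
accuracy = (predictions == test_labels).mean()
print(f"Test accuracy: {accuracy:.4f}")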
Code Example #4
    def test_noop_strategy(self):
        strategy = ls.NoOpLoopingStrategy()
        dataloader = td.LoopingDataGenerator(self.test_set.paths,
                                             dg.get_filelist_within_folder,
                                             _dummy_dataloader_fn,
                                             looping_strategy=strategy)
        self.assertEqual(len(dataloader), 0)
        first_epoch = set(_get_samples_in_epoch(dataloader))
        self.assertEqual(len(first_epoch), self.test_set.num_samples)
        # The no-op strategy never stores samples, so nothing remains to replay.
        self.assertEqual(len(dataloader), 0)
        second_epoch = set(_get_samples_in_epoch(dataloader))
        self.assertEqual(len(second_epoch), 0)
Code Example #5
    def test_strategies(self):
        for strategyfn in self.looping_strategies:
            strategy = strategyfn()
            with self.subTest(msg=f"Checking strategy {type(strategy).__name__}"):
                dataloader = td.LoopingDataGenerator(
                    self.test_set.paths,
                    dg.get_filelist_within_folder,
                    _dummy_dataloader_fn,
                    looping_strategy=strategy)
                self.assertEqual(len(dataloader), 0)
                first_epoch = set(_get_samples_in_epoch(dataloader))
                self.assertEqual(len(dataloader), self.test_set.num_samples)
                # A second pass must replay exactly the samples of epoch one.
                second_epoch = set(_get_samples_in_epoch(dataloader))
                self.assertSetEqual(first_epoch, second_epoch)
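Taken together, Examples #1, #4 and #5 pin down the contract a looping strategy must satisfy: it buffers the samples served during the first epoch (len(dataloader) apparently tracks how many samples the strategy holds), replays them on every later epoch, and optionally persists its buffer via dump_content/load_content, raising NotImplementedError when it cannot. A minimal sketch of such a strategy; only dump_content and load_content appear in the listing, every other name here is an assumption:

import pickle


class ListLoopingStrategy:
    """Hypothetical strategy: keep epoch one in a list, replay it afterwards."""

    def __init__(self):
        self._samples = []

    def store(self, sample):
        # Assumed hook the generator calls while serving epoch one.
        self._samples.append(sample)

    def __len__(self):
        return len(self._samples)

    def get_new_iterator(self):
        # Assumed hook for every epoch after the first.
        return iter(self._samples)

    def dump_content(self, path):
        # Persist the buffer so a fresh instance can resume via load_content,
        # which is what Example #1 verifies.
        with open(path, "wb") as f:
            pickle.dump(self._samples, f)

    def load_content(self, path):
        with open(path, "rb") as f:
            self._samples = pickle.load(f)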
Code Example #6
    def test_splits_add_up(self):
        with tempfile.TemporaryDirectory(prefix="TorchDG_Splits") as splitpath:
            splitpath = Path(splitpath)
            td.LoopingDataGenerator(
                self.test_set.paths,
                dg.get_filelist_within_folder,
                _dummy_dataloader_fn,
                split_save_path=splitpath,
                num_validation_samples=self.num_validation_samples,
                num_test_samples=self.num_validation_samples)  # test split sized like the validation split

            split_files = [
                splitpath / f"{name}_set.p"
                for name in ["training", "validation", "test"]
            ]
            for f in split_files:
                self.assertTrue(f.exists(), msg=f"Split file {f} is missing.")

            def load_pickled_filenames(fn):
                with open(fn, "rb") as f:
                    return pickle.load(f)

            splits = [(fn, load_pickled_filenames(fn)) for fn in split_files]

            for (a_name, a_files), (b_name, b_files) in itertools.combinations(splits, 2):
                self.assertFalse(
                    set(a_files) & set(b_files),
                    msg=f"Files in {a_name} and {b_name} intersect!")

            files_in_splits = sum((files for _, files in splits), [])
            self.assertCountEqual(
                map(str, self.test_set.erf_files),
                files_in_splits,
                msg="Combining splits should result in the original file set!")
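Examples #2 and #3 pass a split_load_path, which suggests the three pickle files written via split_save_path can seed a later run. A sketch of that round trip, reusing the names from the test above; whether the directory can be handed over exactly like this is an assumption:

dataloader = td.LoopingDataGenerator(
    self.test_set.paths,
    dg.get_filelist_within_folder,
    _dummy_dataloader_fn,
    split_load_path=splitpath)  # reads the training/validation/test_set.p files written above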
Code Example #7
import pickle
from pathlib import Path

import torch

import Resources.training as r
from Pipeline import torch_datagenerator as td
from Pipeline.data_gather import get_filelist_within_folder_blacklisted
from Pipeline.data_loader_dryspot import DataloaderDryspots

if __name__ == "__main__":
    dlds = DataloaderDryspots(divide_by_100k=False)
    batch_size = 131072
    generator = td.LoopingDataGenerator(r.get_data_paths_base_0(),
                                        get_filelist_within_folder_blacklisted,
                                        dlds.get_sensor_bool_dryspot,
                                        num_validation_samples=131072,
                                        num_test_samples=1048576,
                                        batch_size=batch_size,
                                        split_load_path=r.dataset_split,
                                        split_save_path=Path(),
                                        num_workers=75,
                                        looping_strategy=None)
    all_sensor_inputs = []
    for i, (inputs, _, _) in enumerate(generator):
        all_sensor_inputs.append(inputs)
        print(i)  # crude progress output, one line per batch
    all_sensor_values = torch.cat(all_sensor_inputs, dim=0)
    _std = all_sensor_values.std(dim=0)
    _mean = all_sensor_values.mean(dim=0)
    print("Std\n", _std)
    print("Mean\n", _mean)
    with open("mean_std_1140_pressure_sensors.p", "wb") as f:
        pickle.dump((_mean, _std), f)
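The pickled statistics are presumably reloaded later to standardize the sensor inputs. A sketch of that consumption, where inputs stands for any batch of raw sensor values; the normalization step itself is not part of the listing:

with open("mean_std_1140_pressure_sensors.p", "rb") as f:
    mean, std = pickle.load(f)
# `inputs` stands for any batch of raw sensor values (assumed, not in the listing).
normalized_inputs = (inputs - mean) / std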
Code Example #8
if __name__ == "__main__":
    num_samples = 150000

    args = read_cmd_params()
    print("Using ca. 150 000 samples.")

    dl = DataloaderDryspots(sensor_indizes=((0, 1), (0, 1)))
    print("Created Dataloader.")

    generator = td.LoopingDataGenerator(
        r.get_data_paths_base_0(),
        get_filelist_within_folder_blacklisted,
        dl.get_sensor_bool_dryspot,
        batch_size=512,
        num_validation_samples=131072,
        num_test_samples=1048576,
        split_load_path=r.datasets_dryspots, 
        split_save_path=r.save_path,
        num_workers=75,
        cache_path=r.cache_path,
        cache_mode=td.CachingMode.Both,
        dont_care_num_samples=False,
        test_mode=False)
    print("Created Datagenerator")

    train_data = []
    train_labels = []

    for inputs, labels, _ in generator:
        train_data.extend(inputs.numpy())
        train_labels.extend(labels.numpy())
        if len(train_data) > num_samples:
            break  # stop collecting once roughly 150 000 samples are gathered
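The listing breaks off at this point. A plausible continuation, mirroring the array conversion in Example #3; every line below is an assumption, including numpy being imported as np:

    train_data = np.array(train_data[:num_samples])
    train_labels = np.ravel(np.array(train_labels[:num_samples]))
    print(f"Collected {len(train_data)} training samples.")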