def test_dataloader(self, *args,
                     **kwargs) -> Union[DataLoader, List[DataLoader]]:
     """Build the DataLoader over the ``'test'`` split.

     The ground-truth test split is wrapped in
     ``SResFourierCoefficientDataset`` using the amplitude bounds held in
     ``self.mag_min`` / ``self.mag_max`` (populated by ``setup``). Those
     attributes are therefore expected to exist before this is called.

     ``*args`` / ``**kwargs`` are ignored.

     Returns:
         A ``DataLoader`` with batch size ``self.batch_size`` and otherwise
         default settings (no ``num_workers`` override, unlike the
         train/validation loaders).
     """
     test_split = self.gt_ds.create_torch_dataset(part='test')
     coeff_ds = SResFourierCoefficientDataset(test_split,
                                              amp_min=self.mag_min,
                                              amp_max=self.mag_max)
     return DataLoader(coeff_ds, batch_size=self.batch_size)
 def train_dataloader(self, *args, **kwargs) -> DataLoader:
     """Build the DataLoader over the ``'train'`` split.

     The ground-truth training split is wrapped in
     ``SResFourierCoefficientDataset`` using the amplitude bounds held in
     ``self.mag_min`` / ``self.mag_max`` (populated by ``setup``). Those
     attributes are therefore expected to exist before this is called.

     ``*args`` / ``**kwargs`` are ignored.

     Returns:
         A ``DataLoader`` with batch size ``self.batch_size``, one worker
         process, and per-epoch shuffling enabled.
     """
     train_split = self.gt_ds.create_torch_dataset(part='train')
     coeff_ds = SResFourierCoefficientDataset(train_split,
                                              amp_min=self.mag_min,
                                              amp_max=self.mag_max)
     # DataLoader defaults to shuffle=False, which would feed the training
     # samples in the same fixed order every epoch; shuffle explicitly.
     return DataLoader(coeff_ds,
                       batch_size=self.batch_size,
                       shuffle=True,
                       num_workers=1)
 def setup(self, stage: Optional[str] = None):
     """Derive amplitude normalization bounds from the training split.

     A temporary ``SResFourierCoefficientDataset`` is built over the
     ``'train'`` split with ``amp_min=None`` and ``amp_max=None``. The
     ``amp_min`` / ``amp_max`` it then exposes are stored as
     ``self.mag_min`` / ``self.mag_max`` for the dataloader methods to pass
     on. Presumably the dataset computes these bounds itself when given
     ``None``; confirm against its implementation.

     Args:
         stage: Ignored. The bounds are recomputed from the training split
             on every call, whatever the stage.
     """
     stats_ds = SResFourierCoefficientDataset(
         self.gt_ds.create_torch_dataset(part='train'),
         amp_min=None,
         amp_max=None,
     )
     self.mag_min, self.mag_max = stats_ds.amp_min, stats_ds.amp_max
 def val_dataloader(self, *args,
                    **kwargs) -> Union[DataLoader, List[DataLoader]]:
     """Build the DataLoader over the ``'validation'`` split.

     The ground-truth validation split is wrapped in
     ``SResFourierCoefficientDataset`` using the amplitude bounds held in
     ``self.mag_min`` / ``self.mag_max`` (populated by ``setup``). Those
     attributes are therefore expected to exist before this is called.

     ``*args`` / ``**kwargs`` are ignored.

     Returns:
         A ``DataLoader`` with batch size ``self.batch_size`` and a single
         worker process; sample order is not shuffled.
     """
     val_split = self.gt_ds.create_torch_dataset(part='validation')
     coeff_ds = SResFourierCoefficientDataset(val_split,
                                              amp_min=self.mag_min,
                                              amp_max=self.mag_max)
     loader_kwargs = dict(batch_size=self.batch_size, num_workers=1)
     return DataLoader(coeff_ds, **loader_kwargs)