def train_dataloader(self):
    """Build the training data pipeline.

    Reads ``covid_train_v5.csv`` under ``self.hparams.data`` through
    ``MultiLabelDataset``, augments component 0 of every datapoint,
    groups datapoints into batches of ``self.hparams.batch`` (the final
    partial batch is kept), runs the flow through ``MultiProcessRunner``
    with four processes and finally converts every batch to torch tensors.

    Returns:
        The wrapped dataflow. Each item is a two-element list
        ``[images, labels]``: ``images`` is component 0 with its axes
        permuted by ``(0, 3, 1, 2)``; ``labels`` is component 1 cast to
        float.
    """
    # NOTE(review): this method reads ``hparams.data`` whereas
    # ``val_dataloader`` reads ``hparams.data_path`` -- confirm both
    # attributes exist and point at the intended folder.
    ds_train = MultiLabelDataset(
        folder=self.hparams.data,
        is_train='train',
        fname='covid_train_v5.csv',
        types=self.hparams.types,
        pathology=self.hparams.pathology,
        resize=int(self.hparams.shape),
        balancing=None,
    )
    ds_train.reset_state()

    ag_train = [
        # Everything between the two ColorSpace entries operates on the
        # RGB-converted image; the last ColorSpace goes back to grayscale.
        imgaug.ColorSpace(mode=cv2.COLOR_GRAY2RGB),
        imgaug.RandomChooseAug([
            imgaug.Albumentations(AB.Blur(blur_limit=4, p=0.25)),
            imgaug.Albumentations(AB.MotionBlur(blur_limit=4, p=0.25)),
            imgaug.Albumentations(AB.MedianBlur(blur_limit=4, p=0.25)),
        ]),
        imgaug.Albumentations(AB.CLAHE(tile_grid_size=(32, 32), p=0.5)),
        imgaug.RandomOrderAug([
            imgaug.Affine(shear=10,
                          border=cv2.BORDER_CONSTANT,
                          interp=cv2.INTER_AREA),
            imgaug.Affine(translate_frac=(0.01, 0.02),
                          border=cv2.BORDER_CONSTANT,
                          interp=cv2.INTER_AREA),
            imgaug.Affine(scale=(0.5, 1.0),
                          border=cv2.BORDER_CONSTANT,
                          interp=cv2.INTER_AREA),
        ]),
        imgaug.RotationAndCropValid(max_deg=10, interp=cv2.INTER_AREA),
        imgaug.GoogleNetRandomCropAndResize(
            crop_area_fraction=(0.8, 1.0),
            aspect_ratio_range=(0.8, 1.2),
            interp=cv2.INTER_AREA,
            target_shape=self.hparams.shape),
        imgaug.ColorSpace(mode=cv2.COLOR_RGB2GRAY),
        imgaug.ToFloat32(),
    ]
    # Only component 0 is augmented; nothing is applied to component 1.
    ds_train = AugmentImageComponent(ds_train, ag_train, 0)

    ds_train = BatchData(ds_train, self.hparams.batch, remainder=True)
    if self.hparams.debug:
        # Debug runs wrap the batched flow in FixedSizeData with size 2.
        ds_train = FixedSizeData(ds_train, 2)
    ds_train = MultiProcessRunner(ds_train, num_proc=4, num_prefetch=16)
    ds_train = PrintData(ds_train)
    # The (0, 3, 1, 2) permutation requires a rank-4 image batch;
    # presumably (N, H, W, C) -> (N, C, H, W) -- TODO confirm.
    ds_train = MapData(
        ds_train,
        lambda dp: [
            torch.tensor(np.transpose(dp[0], (0, 3, 1, 2))),
            torch.tensor(dp[1]).float(),
        ])
    return ds_train
def val_dataloader(self):
    """Build the validation data pipeline.

    Reads ``valid.csv`` under ``self.hparams.data_path`` through
    ``MultiLabelDataset``, applies a fixed (``p=1``) set of augmentors to
    component 0, batches by ``self.hparams.batch`` (the final partial
    batch is kept) and converts every batch to torch tensors. No
    multi-process runner is used here.

    Returns:
        The wrapped dataflow. Each item is a two-element list
        ``[images, labels]``: ``images`` is component 0 with its axes
        permuted by ``(0, 3, 1, 2)``; ``labels`` is component 1 cast to
        float.
    """

    def _to_tensors(dp):
        # The (0, 3, 1, 2) permutation requires a rank-4 image batch;
        # presumably (N, H, W, C) -> (N, C, H, W) -- TODO confirm.
        images = torch.tensor(np.transpose(dp[0], (0, 3, 1, 2)))
        labels = torch.tensor(dp[1]).float()
        return [images, labels]

    flow = MultiLabelDataset(
        folder=self.hparams.data_path,
        is_train='valid',
        fname='valid.csv',
        types=self.hparams.types,
        pathology=self.hparams.pathology,
        resize=int(self.hparams.shape),
    )
    flow.reset_state()

    # NOTE(review): images stay in RGB here (there is no RGB2GRAY step),
    # while train_dataloader converts back to grayscale -- confirm the
    # differing channel count is intended.
    resize_aug = imgaug.Albumentations(
        AB.SmallestMaxSize(self.hparams.shape, p=1.0))
    to_rgb_aug = imgaug.ColorSpace(mode=cv2.COLOR_GRAY2RGB)
    clahe_aug = imgaug.Albumentations(AB.CLAHE(p=1))
    to_float_aug = imgaug.ToFloat32()
    augmentors = [resize_aug, to_rgb_aug, clahe_aug, to_float_aug]

    flow = AugmentImageComponent(flow, augmentors, 0)
    flow = BatchData(flow, self.hparams.batch, remainder=True)
    flow = PrintData(flow)
    return MapData(flow, _to_tensors)