Example #1
def get_config():
    #anchors = np.mgrid[0:4,0:4][:,1:,1:].transpose(1,2,0).reshape((-1,2)) / 4.0
    # prepare dataset
    d1 = dataset.SVHNDigit('train')
    d2 = dataset.SVHNDigit('extra')
    train = RandomMixData([d1, d2])
    test = dataset.SVHNDigit('test')

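    # training-time augmentation: resize to 40x40 plus brightness, contrast, and elastic-deformation jitter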
    augmentors = [
        imgaug.Resize((40, 40)),
        imgaug.BrightnessAdd(30),
        imgaug.Contrast((0.5, 1.5)),
        imgaug.GaussianDeform(  # this is slow
            [(0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2)], (40, 40), 0.2,
            3),
    ]
    train = AugmentImageComponent(train, augmentors)
    train = BatchData(train, 128)
    nr_proc = 5
    train = PrefetchData(train, 5, nr_proc)
    step_per_epoch = train.size()

    augmentors = [
        imgaug.Resize((40, 40)),
    ]
    test = AugmentImageComponent(test, augmentors)
    test = BatchData(test, 128, remainder=True)

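    # tensorpack's default session config; 0.8 is the fraction of GPU memory to reserve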
    sess_config = get_default_sess_config(0.8)

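    # staircase decay: multiply the learning rate by 0.2 every 60 epochs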
    lr = tf.train.exponential_decay(learning_rate=1e-3,
                                    global_step=get_global_step_var(),
                                    decay_steps=train.size() * 60,
                                    decay_rate=0.2,
                                    staircase=True,
                                    name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    return TrainConfig(
        dataset=train,
        optimizer=tf.train.AdamOptimizer(lr),
        callbacks=Callbacks([
            StatPrinter(),
            ModelSaver(),
            InferenceRunner(test,
                            [ScalarStats('cost'),
                             ClassificationError()])
        ]),
        session_config=sess_config,
        model=Model(),
        step_per_epoch=step_per_epoch,
        max_epoch=350,
    )
Example #2
    def test_dataloader(self):
        """Summary

        Returns:
            TYPE: Description
        """
        ds_test = CustomDataSet(folder=self.hparams.data,
                                train_or_valid='test',
                                size=np.inf,
                                hparams=self.hparams)

        ds_test.reset_state()
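        # deterministic test-time transforms: nearest-neighbour resize and float conversion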
        ag_test = [
            imgaug.Resize(self.hparams.shape, interp=cv2.INTER_NEAREST),
            imgaug.ToFloat32(),
        ]
        # ds_test = AugmentImageComponent(ds_test, [imgaug.Albumentations(AB.CLAHE(tile_grid_size=(32, 32), always_apply=True, p=1)),], 0)
        ds_test = AugmentImageComponents(ds_test, ag_test, [0, 1])
        ds_test = BatchData(ds_test, self.hparams.batch, remainder=True)
        ds_test = MultiProcessRunner(ds_test, num_proc=4, num_prefetch=16)
        ds_test = PrintData(ds_test)
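        # add a channel axis and convert both components to float tensors for PyTorch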
        ds_test = MapData(
            ds_test, lambda dp: [
                torch.tensor(dp[0][:, np.newaxis, :, :]).float(),
                torch.tensor(dp[1][:, np.newaxis, :, :]).float(),
            ])
        return ds_test
Example #3
    def test_dataloader(self):
        ds_test = MultiLabelDataset(folder=self.hparams.data,
                                    is_train='valid',
                                    fname='covid_test_v5.csv',
                                    types=self.hparams.types,
                                    pathology=self.hparams.pathology,
                                    resize=int(self.hparams.shape),
                                    fold_idx=None,
                                    n_folds=1)

        ds_test.reset_state()
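        # test-time transforms: area-interpolated resize and float conversion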
        ag_test = [
            imgaug.Resize(self.hparams.shape, interp=cv2.INTER_AREA),
            imgaug.ToFloat32(),
        ]
        ds_test = AugmentImageComponent(ds_test, ag_test, 0)
        ds_test = BatchData(ds_test, self.hparams.batch, remainder=True)
        ds_test = MultiProcessRunner(ds_test, num_proc=4, num_prefetch=16)
        ds_test = PrintData(ds_test)
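        # transpose images from NHWC to NCHW and cast labels to float tensors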
        ds_test = MapData(
            ds_test, lambda dp: [
                torch.tensor(np.transpose(dp[0], (0, 3, 1, 2))),
                torch.tensor(dp[1]).float()
            ])
        return ds_test
Example #4
    def val_dataloader(self):
        """Summary

        Returns:
            TYPE: Description
        """
        ds_valid = MultiLabelDataset(
            folder=self.hparams.data,
            is_train='valid',
            fname='covid_test_v5.csv',
            types=self.hparams.types,
            pathology=self.hparams.pathology,
            resize=int(self.hparams.shape),
        )

        ds_valid.reset_state()
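        # validation-time transforms: area-interpolated resize and float conversion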
        ag_valid = [
            imgaug.Resize(self.hparams.shape, interp=cv2.INTER_AREA),
            imgaug.ToFloat32(),
        ]
        ds_valid = AugmentImageComponent(ds_valid, ag_valid, 0)
        ds_valid = BatchData(ds_valid, self.hparams.batch, remainder=True)
        ds_valid = MultiProcessRunner(ds_valid, num_proc=4, num_prefetch=16)
        ds_valid = PrintData(ds_valid)
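        # transpose images from NHWC to NCHW and cast labels to float tensors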
        ds_valid = MapData(
            ds_valid, lambda dp: [
                torch.tensor(np.transpose(dp[0], (0, 3, 1, 2))),
                torch.tensor(dp[1]).float()
            ])
        return ds_valid
Example #5
    def val_dataloader(self):
        """Summary

        Returns:
            TYPE: Description
        """
        ds_valid = CustomDataSet(folder=self.hparams.data,
                                 train_or_valid='valid',
                                 size=np.inf,
                                 hparams=self.hparams)

        ds_valid.reset_state()
        ag_valid = [
            imgaug.Resize(self.hparams.shape, interp=cv2.INTER_NEAREST),
            imgaug.ToFloat32(),
        ]
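        # CLAHE is applied to the image component only; ag_valid is then applied to every component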
        ds_valid = AugmentImageComponent(ds_valid, [
            imgaug.Albumentations(AB.CLAHE(p=1)),
        ], 0)
        if self.hparams.types == 6:
            ds_valid = AugmentImageComponents(ds_valid, ag_valid,
                                              [0, 1, 2, 3, 4, 5, 6])
        elif self.hparams.types == 1:
            ds_valid = AugmentImageComponents(ds_valid, ag_valid, [0, 1])
        ds_valid = BatchData(ds_valid, self.hparams.batch, remainder=True)
        ds_valid = MultiProcessRunner(ds_valid, num_proc=4, num_prefetch=16)
        ds_valid = PrintData(ds_valid)
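        # add a channel axis to each component and convert to float tensors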
        if self.hparams.types == 6:
            ds_valid = MapData(
                ds_valid, lambda dp: [
                    torch.tensor(dp[0][:, np.newaxis, :, :]).float(),
                    torch.tensor(dp[1][:, np.newaxis, :, :]).float(),
                    torch.tensor(dp[2][:, np.newaxis, :, :]).float(),
                    torch.tensor(dp[3][:, np.newaxis, :, :]).float(),
                    torch.tensor(dp[4][:, np.newaxis, :, :]).float(),
                    torch.tensor(dp[5][:, np.newaxis, :, :]).float(),
                    torch.tensor(dp[6][:, np.newaxis, :, :]).float(),
                ])
        elif self.hparams.types == 1:
            ds_valid = MapData(
                ds_valid, lambda dp: [
                    torch.tensor(dp[0][:, np.newaxis, :, :]).float(),
                    torch.tensor(dp[1][:, np.newaxis, :, :]).float(),
                ])
        return ds_valid
Example #6
    def train_dataloader(self):
        """Summary

        Returns:
            TYPE: Description
        """
        ds_train = CustomDataSet(folder=self.hparams.data,
                                 train_or_valid='train',
                                 size=np.inf,
                                 hparams=self.hparams)
        ds_train.reset_state()
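        # geometric transforms, applied jointly to components 0 and 1 below with nearest-neighbour interpolation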
        ag_train = [
            imgaug.Affine(shear=10, interp=cv2.INTER_NEAREST),
            imgaug.Affine(translate_frac=(0.01, 0.02),
                          interp=cv2.INTER_NEAREST),
            imgaug.Affine(scale=(0.25, 1.0), interp=cv2.INTER_NEAREST),
            imgaug.RotationAndCropValid(max_deg=10, interp=cv2.INTER_NEAREST),
            imgaug.GoogleNetRandomCropAndResize(
                crop_area_fraction=(0.8, 1.0),
                aspect_ratio_range=(0.8, 1.2),
                interp=cv2.INTER_NEAREST,
                target_shape=self.hparams.shape),
            imgaug.Resize(self.hparams.shape, interp=cv2.INTER_NEAREST),
            imgaug.Flip(horiz=True, vert=False, prob=0.5),
            imgaug.Flip(horiz=False, vert=True, prob=0.5),
            imgaug.Transpose(prob=0.5),
            imgaug.Albumentations(AB.RandomRotate90(p=1)),
            imgaug.ToFloat32(),
        ]
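        # photometric augmentations (blur, brightness/contrast, CLAHE) applied to the image component only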
        ds_train = AugmentImageComponent(
            ds_train,
            [
                # imgaug.Float32(),
                # imgaug.RandomChooseAug([
                #     imgaug.Albumentations(AB.IAAAdditiveGaussianNoise(p=0.25)),
                #     imgaug.Albumentations(AB.GaussNoise(p=0.25)),
                #     ]),
                # imgaug.ToUint8(),
                imgaug.RandomChooseAug([
                    imgaug.Albumentations(AB.Blur(blur_limit=4, p=0.25)),
                    imgaug.Albumentations(AB.MotionBlur(blur_limit=4, p=0.25)),
                    imgaug.Albumentations(AB.MedianBlur(blur_limit=4, p=0.25)),
                ]),
                imgaug.RandomChooseAug([
                    # imgaug.Albumentations(AB.IAASharpen(p=0.5)),
                    # imgaug.Albumentations(AB.IAAEmboss(p=0.5)),
                    imgaug.Albumentations(AB.RandomBrightnessContrast(p=0.5)),
                ]),
                imgaug.ToUint8(),
                imgaug.Albumentations(AB.CLAHE(tile_grid_size=(32, 32),
                                               p=0.5)),
            ],
            0)
        ds_train = AugmentImageComponents(ds_train, ag_train, [0, 1])

        ds_train = BatchData(ds_train, self.hparams.batch, remainder=True)
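        # in debug mode, keep only two batches for a quick sanity check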
        if self.hparams.debug:
            ds_train = FixedSizeData(ds_train, 2)
        ds_train = MultiProcessRunner(ds_train, num_proc=4, num_prefetch=16)
        ds_train = PrintData(ds_train)

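        # add a channel axis and convert both components to float tensors for PyTorch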
        ds_train = MapData(
            ds_train, lambda dp: [
                torch.tensor(dp[0][:, np.newaxis, :, :]).float(),
                torch.tensor(dp[1][:, np.newaxis, :, :]).float(),
            ])
        return ds_train