Example #1
def common_aug(mode, params):
    '''
    :param mode: one of 'train', 'debug', 'inference'
    :param params: config object exposing `params.data.net_hw` and `params.augmentation`
    '''
    #aug_params = params.get('augm_params', dict())
    augs_list = []
    assert mode in {'train', 'debug', 'inference'}
    if mode == 'train':
        augs_list.append(albumentations.PadIfNeeded(min_height=params.data.net_hw[0], min_width=params.data.net_hw[1],
                                                    border_mode=cv2.BORDER_REPLICATE,
                                                    always_apply=True))
        augs_list.append(albumentations.RandomCrop(height=params.data.net_hw[0], width=params.data.net_hw[1], always_apply=True))
        if params.augmentation.rotate_limit:
            augs_list.append(T.Rotate(limit=params.augmentation.rotate_limit, border_mode=cv2.BORDER_CONSTANT, always_apply=True))
        # augs_list.append(T.OpticalDistortion(border_mode=cv2.BORDER_CONSTANT)) - can't handle boundboxes
    elif mode == 'debug':
        augs_list.append(albumentations.CenterCrop(height=params.data.net_hw[0], width=params.data.net_hw[1], always_apply=True))
    if mode != 'inference':
        if params.augmentation.get('blur_limit', 4):
            augs_list.append(T.Blur(blur_limit=params.augmentation.get('blur_limit', 4)))
        if params.augmentation.get('RandomBrightnessContrast', True):
            augs_list.append(T.RandomBrightnessContrast())
        #augs_list.append(T.MotionBlur())
        if params.augmentation.get('JpegCompression', True):
            augs_list.append(T.JpegCompression(quality_lower=30, quality_upper=100))
        #augs_list.append(T.VerticalFlip())
        if params.augmentation.get('HorizontalFlip', True):
            augs_list.append(T.HorizontalFlip())

    return albumentations.ReplayCompose(
        augs_list, p=1.,
        bbox_params={'format': 'albumentations', 'min_visibility': 0.5})
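
A minimal usage sketch (assuming `image`, `bboxes`, `other_image` arrays and a suitable `params` object; names are illustrative): `ReplayCompose` records the parameters it sampled, so the same augmentation can be replayed on another image.

aug = common_aug('train', params)
data = aug(image=image, bboxes=bboxes)  # bboxes: normalized [x_min, y_min, x_max, y_max, label]
replayed = albumentations.ReplayCompose.replay(data['replay'], image=other_image)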
Example #2
    def __init__(
        self,
        prob=0.7,
        blur_prob=0.7,
        jitter_prob=0.7,
        rotate_prob=0.7,
        flip_prob=0.7,
    ):
        super().__init__()

        self.prob = prob
        self.blur_prob = blur_prob
        self.jitter_prob = jitter_prob
        self.rotate_prob = rotate_prob
        self.flip_prob = flip_prob

        self.transforms = al.Compose(
            [
                transforms.RandomRotate90(),
                transforms.Flip(),
                transforms.HueSaturationValue(),
                transforms.RandomBrightnessContrast(),
                transforms.Transpose(),
                OneOf([
                    transforms.RandomCrop(220, 220, p=0.5),
                    transforms.CenterCrop(220, 220, p=0.5)
                ],
                      p=0.5),
                # transforms.Resize(352,352),
                # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            p=self.prob)
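
A usage sketch with a hypothetical wrapper name `TrainAug` (only `__init__` is shown above); note that `OneOf` renormalizes its children's probabilities, so each crop branch is equally likely whenever the `OneOf` itself fires:

aug = TrainAug(prob=0.7)                    # hypothetical class name
out = aug.transforms(image=img, mask=mask)  # img, mask: NumPy arrays
img_aug, mask_aug = out["image"], out["mask"]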
Example #3
def augmentation(mode, target_size, prob=0.5, aug_m=2):
    '''
    description: build the augmentation pipeline
    mode: 'train' or 'test'
    target_size: list/tuple of two ints (h, w), the shape of the output image
    aug_m: strength of the transforms
    '''
    high_p = prob
    low_p = high_p / 2.0
    M = aug_m
    first_size = [int(x / 0.7) for x in target_size]

    if mode == 'train':
        return composition.Compose([
            transforms.Resize(first_size[0], first_size[1], interpolation=3),
            transforms.Flip(p=0.5),
            composition.OneOf([
                RandomCenterCut(scale=0.1 * M),
                transforms.ShiftScaleRotate(shift_limit=0.05 * M,
                                            scale_limit=0.1 * M,
                                            rotate_limit=180,
                                            border_mode=cv2.BORDER_CONSTANT,
                                            value=0),
                albumentations.imgaug.transforms.IAAAffine(
                    shear=(-10 * M, 10 * M), mode='constant')
            ],
                              p=high_p),
            transforms.RandomBrightnessContrast(
                brightness_limit=0.1 * M, contrast_limit=0.03 * M, p=high_p),
            transforms.HueSaturationValue(hue_shift_limit=5 * M,
                                          sat_shift_limit=15 * M,
                                          val_shift_limit=10 * M,
                                          p=high_p),
            transforms.OpticalDistortion(distort_limit=0.03 * M,
                                         shift_limit=0,
                                         border_mode=cv2.BORDER_CONSTANT,
                                         value=0,
                                         p=low_p),
            composition.OneOf([
                transforms.Blur(blur_limit=7),
                albumentations.imgaug.transforms.IAASharpen(),
                transforms.GaussNoise(var_limit=(2.0, 10.0), mean=0),
                transforms.ISONoise()
            ],
                              p=low_p),
            transforms.Resize(target_size[0], target_size[1], interpolation=3),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5),
                                 max_pixel_value=255.0)
        ],
                                   p=1)

    else:
        return composition.Compose([
            transforms.Resize(target_size[0], target_size[1], interpolation=3),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5),
                                 max_pixel_value=255.0)
        ],
                                   p=1)
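
A hypothetical call, assuming `img` is an RGB uint8 array:

train_aug = augmentation('train', target_size=[224, 224])
img_aug = train_aug(image=img)['image']   # float array after Normalize
test_aug = augmentation('test', target_size=[224, 224])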
Example #4
def RandomBrightnessContrast(brightness_limit=0.2,
                             contrast_limit=0.2,
                             brightness_by_max=True,
                             p=0.5):
    return transforms.RandomBrightnessContrast(
        brightness_limit=brightness_limit,
        contrast_limit=contrast_limit,
        brightness_by_max=brightness_by_max,
        p=p)
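
Usage matches the wrapped transform, e.g. (assuming an image array `img`):

aug = RandomBrightnessContrast(p=1.0)   # always fire, handy for visual inspection
out = aug(image=img)['image']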
Example #5
def get_presize_combine_transforms_V4():
    transforms_presize = A.Compose([
        transforms.PadIfNeeded(600, 800),
        geometric.Perspective(
            scale=[0, .1],
            pad_mode=cv2.BORDER_REFLECT,
            interpolation=cv2.INTER_AREA, p = .3),
        transforms.Flip(),
        geometric.ShiftScaleRotate(interpolation=cv2.INTER_LANCZOS4, p = 0.95, scale_limit=0.0),
        crops.RandomResizedCrop(
            TARGET_SIZE, TARGET_SIZE,
            scale=(config['rrc_scale_min'], config['rrc_scale_max']),
            ratio=(.70, 1.4),
            interpolation=cv2.INTER_CUBIC,
            p=1.0),
        transforms.Transpose()
        #rotate.Rotate(interpolation=cv2.INTER_LANCZOS4, p = 0.99),
    ])
    
    transforms_postsize = A.Compose([
        #imgaug.IAAPiecewiseAffine(),

        transforms.CoarseDropout(),
        transforms.CLAHE(p=.1),
        transforms.RandomToneCurve(scale=.1, p=0.2),
        transforms.RandomBrightnessContrast(
            brightness_limit=.1, 
            contrast_limit=0.4,
            p=.8),
        transforms.HueSaturationValue(
            hue_shift_limit=20, 
            sat_shift_limit=50,
            val_shift_limit=0, 
            p=0.5),
        transforms.Equalize(p=0.05),
        transforms.FancyPCA(p=0.05),
        transforms.RandomGridShuffle(p=0.1),
        A.OneOf([
                transforms.MotionBlur(blur_limit=(3, 9)),
                transforms.GaussianBlur(),
                transforms.MedianBlur()
            ], p=0.1),
        transforms.ISONoise(p=.2),
        transforms.GaussNoise(var_limit=127., p=.3),
        A.OneOf([
            transforms.GridDistortion(interpolation=cv2.INTER_AREA, distort_limit=[0.7, 0.7], p=0.5),
            transforms.OpticalDistortion(interpolation=cv2.INTER_AREA, p=.3),
        ], p=.3),
        geometric.ElasticTransform(alpha=4, sigma=4, alpha_affine=4, interpolation=cv2.INTER_AREA, p=0.3),
        transforms.CoarseDropout(),
        transforms.Normalize(),
        ToTensorV2()
    ])
    return transforms_presize, transforms_postsize
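
A sketch of the intended two-stage call (assuming `TARGET_SIZE`, `config`, and an input array `img`): geometric work happens at full resolution, photometric work on the resized crop.

pre, post = get_presize_combine_transforms_V4()
resized = pre(image=img)['image']      # pad/perspective/flip/crop at full resolution
tensor = post(image=resized)['image']  # photometric noise, Normalize, ToTensorV2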
Example #6
    def __init__(self, brightness_coeff: float = -1) -> None:
        """
        Initialize the brightener.

        :param brightness_coeff: coefficient between 0 and 1 that controls brightness;
            if brightness_coeff is -1, the value is drawn uniformly from [0.0, 1.0]
            on each application
        """
        self.brightness_coeff = brightness_coeff
        brightness_coeff_lower, brightness_coeff_upper = brightness_coeff, brightness_coeff
        if self.brightness_coeff == -1:
            brightness_coeff_lower, brightness_coeff_upper = 0.0, 1.0
        self.brighten_object = albu.RandomBrightnessContrast(
            (brightness_coeff_lower, brightness_coeff_upper), (0.0, 0.0),
            always_apply=True)
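
A usage sketch with a hypothetical wrapper class name `Brightener` (only `__init__` is shown):

brightener = Brightener(brightness_coeff=0.4)         # fixed strength
out = brightener.brighten_object(image=img)['image']  # img: NumPy array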
Example #7
def get_presize_combine_tune_transforms():
    transforms_presize = A.Compose([
        transforms.Transpose(),
        transforms.Flip(),
        #transforms.PadIfNeeded(600, 800),
        crops.RandomResizedCrop(
            TARGET_SIZE, TARGET_SIZE,
            scale=(.75, 1),
            interpolation=cv2.INTER_CUBIC,
            p=1.0),
        rotate.Rotate(interpolation=cv2.INTER_LANCZOS4, p = 0.99),
    ])
    
    transforms_postsize = A.Compose([
        transforms.CoarseDropout(),
        # transforms.CLAHE(p=.1),
        transforms.RandomToneCurve(scale=.1),
        transforms.RandomBrightnessContrast(
            brightness_limit=.1, 
            contrast_limit=0.2,
            p=.7),
        transforms.HueSaturationValue(
            hue_shift_limit=20, 
            sat_shift_limit=60,
            val_shift_limit=0, 
            p=0.6),
        #transforms.Equalize(p=0.1),
        #transforms.FancyPCA(p=0.05),
        #transforms.RandomGridShuffle(p=0.1),
        #A.OneOf([
        #        transforms.MotionBlur(blur_limit=(3, 9)),
        #        transforms.GaussianBlur(),
        #        transforms.MedianBlur()
        #    ], p=0.2),
        transforms.ISONoise(p=.3),
        transforms.GaussNoise(var_limit=255., p=.3),
        #A.OneOf([
        #     transforms.GridDistortion(interpolation=cv2.INTER_AREA, distort_limit=[0.7, 0.7], p=0.5),
        #     transforms.OpticalDistortion(interpolation=cv2.INTER_AREA, p=.3),
        # ], p=.3),
        geometric.ElasticTransform(alpha=4, sigma=100, alpha_affine=100, interpolation=cv2.INTER_AREA, p=0.3),
        transforms.CoarseDropout(),
        transforms.Normalize(),
        ToTensorV2()
    ])
    return transforms_presize, transforms_postsize
Example #8
def get_augmentations():
    """Get a list of 'major' and 'minor' augmentation functions for the pipeline in a dictionary."""
    return {
        "major": {
            "shift-scale-rot":
            trans.ShiftScaleRotate(
                shift_limit=0.05,
                rotate_limit=35,
                border_mode=cv2.BORDER_REPLICATE,
                always_apply=True,
            ),
            "crop":
            trans.RandomResizedCrop(100,
                                    100,
                                    scale=(0.8, 0.95),
                                    ratio=(0.8, 1.2),
                                    always_apply=True),
            # "elastic": trans.ElasticTransform(
            #     alpha=0.8,
            #     alpha_affine=10,
            #     sigma=40,
            #     border_mode=cv2.BORDER_REPLICATE,
            #     always_apply=True,
            # ),
            "distort":
            trans.OpticalDistortion(0.2, always_apply=True),
        },
        "minor": {
            "blur":
            trans.GaussianBlur(7, always_apply=True),
            "noise":
            trans.GaussNoise((20.0, 40.0), always_apply=True),
            "bright-contrast":
            trans.RandomBrightnessContrast(0.4, 0.4, always_apply=True),
            "hsv":
            trans.HueSaturationValue(30, 40, 50, always_apply=True),
            "rgb":
            trans.RGBShift(always_apply=True),
            "flip":
            trans.HorizontalFlip(always_apply=True),
        },
    }
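
The source does not show how this registry is consumed; one plausible sketch is to compose all 'major' transforms with a random subset of 'minor' ones (assuming an image array `img`):

import random
import albumentations as A

augs = get_augmentations()
minor = random.sample(list(augs["minor"].values()), k=2)  # k is an arbitrary choice
pipeline = A.Compose(list(augs["major"].values()) + minor)
out = pipeline(image=img)["image"]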
Example #9
def get_train_transforms():
    return A.Compose([
        transforms.PadIfNeeded(600, 800),
        geometric.ShiftScaleRotate(interpolation=cv2.INTER_LANCZOS4, p = 0.99, scale_limit=0.8),
        geometric.Perspective(pad_mode=cv2.BORDER_REFLECT,interpolation=cv2.INTER_AREA),
        crops.RandomResizedCrop(
            TARGET_SIZE, TARGET_SIZE,
            scale=(config['rrc_scale_min'], config['rrc_scale_max']),
            interpolation=cv2.INTER_CUBIC,
            p=1.0),
        transforms.Transpose(),
        transforms.Flip(),
        transforms.CoarseDropout(),
        transforms.CLAHE(p=.1),
        transforms.RandomToneCurve(scale=.1),
        transforms.RandomBrightnessContrast(
            brightness_limit=.1, 
            contrast_limit=0.3,
            p=.7),
        transforms.HueSaturationValue(
            hue_shift_limit=20, 
            sat_shift_limit=60,
            val_shift_limit=0, 
            p=0.6),
        transforms.RandomGridShuffle(p=0.1),
        A.OneOf([
                transforms.MotionBlur(blur_limit=(3, 9)),
                transforms.GaussianBlur(),
                transforms.MedianBlur()
            ], p=0.2),
        transforms.ISONoise(p=.3),
        transforms.GaussNoise(var_limit=255., p=.3),
        A.OneOf([
            transforms.GridDistortion(interpolation=cv2.INTER_AREA, distort_limit=[0.7, 0.7], p=0.5),
            transforms.OpticalDistortion(interpolation=cv2.INTER_AREA, p=.3),
        ], p=.3),
        geometric.ElasticTransform(alpha=4, sigma=100, alpha_affine=100, interpolation=cv2.INTER_AREA, p=0.3),
        transforms.CoarseDropout(),
        transforms.Normalize(),
        ToTensorV2()
    ])
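
Usage mirrors the other pipelines in this listing (`TARGET_SIZE` and `config` assumed):

train_tf = get_train_transforms()
tensor = train_tf(image=img)['image']  # CHW float tensor after Normalize + ToTensorV2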
Example #10
    def train_dataloader(self):
        augmentations = Compose(
            [
                A.RandomResizedCrop(
                    height=self.hparams.sz,
                    width=self.hparams.sz,
                    scale=(0.7, 1.0),
                ),
                # AdvancedHairAugmentation(),
                A.GridDistortion(),
                A.RandomBrightnessContrast(),
                A.ShiftScaleRotate(),
                A.Flip(p=0.5),
                A.CoarseDropout(
                    max_height=int(self.hparams.sz / 10),
                    max_width=int(self.hparams.sz / 10),
                ),
                # A.HueSaturationValue(),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    max_pixel_value=255,
                ),
                ToTensorV2(),
            ]
        )
        train_ds = MelanomaDataset(
            df=self.train_df,
            images_path=self.train_images_path,
            augmentations=augmentations,
            train_or_valid=True,
        )
        return DataLoader(
            train_ds,
            # sampler=sampler,
            batch_size=self.hparams.bs,
            shuffle=True,
            num_workers=os.cpu_count(),
            pin_memory=True,
        )
Example #11
def get_tta_transforms():
    return Compose([
        A.RandomResizedCrop(
            height=hparams.sz,
            width=hparams.sz,
            scale=(0.7, 1.0),
        ),
        # AdvancedHairAugmentation(),
        A.GridDistortion(),
        A.RandomBrightnessContrast(),
        A.ShiftScaleRotate(),
        A.Flip(p=0.5),
        A.CoarseDropout(
            max_height=int(hparams.sz / 10),
            max_width=int(hparams.sz / 10),
        ),
        # A.HueSaturationValue(),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ])
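
A sketch of test-time augmentation with this pipeline (model, device, and `torch` import assumed): run several stochastic passes and average the predictions.

tta = get_tta_transforms()
with torch.no_grad():
    preds = torch.stack([
        model(tta(image=img)['image'].unsqueeze(0).to(device))
        for _ in range(8)   # number of TTA passes is an arbitrary choice
    ])
mean_pred = preds.mean(dim=0)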
Example #12
    train_img_paths = []
    train_mask_paths = []
    train_data_path = ["data/kvasir-seg/TrainDataset"]
    for i in train_data_path:
        train_img_paths.extend(glob(os.path.join(i, "images", "*")))
        train_mask_paths.extend(glob(os.path.join(i, "masks", "*")))
    train_img_paths.sort()
    train_mask_paths.sort()

    # rebinding the name `transforms` works here (the RHS is evaluated first)
    # but shadows the imported module after this statement
    transforms = al.Compose(
        [
            transforms.RandomRotate90(),
            transforms.Flip(),
            transforms.HueSaturationValue(),
            transforms.RandomBrightnessContrast(),
            transforms.Transpose(),
            OneOf(
                [
                    transforms.RandomCrop(220, 220, p=0.5),
                    transforms.CenterCrop(220, 220, p=0.5),
                ],
                p=1,
            ),
            # transforms.Resize(352,352),
            # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        p=1,
    )
    dataset = KvasirDataset(train_img_paths,
                            train_mask_paths,
Example #13
def main():
    # config = vars(parse_args_func())

    #config_file = "../configs/config_v1.json"
    args = vars(parse_args_func())
    config_file = args['config']
    config_dict = json.loads(open(config_file, 'rt').read())
    # config_dict = json.loads(open(sys.argv[1], 'rt').read())

    file_dict = config_dict['file_path']
    config = config_dict['opt_config']

    input_folder = file_dict['input_path']  # '../inputs'
    checkpoint_folder = file_dict['checkpoint_path']  # '../checkpoint'
    model_folder = file_dict['model_path']  # '../models'

    if 'False' in config['deep_supervision']:
        config['deep_supervision'] = False
    else:
        config['deep_supervision'] = True

    if 'False' in config['nesterov']:
        config['nesterov'] = False
    else:
        config['nesterov'] = True

    if 'None' in config['name']:
        config['name'] = None

    if config['name'] is None:
        config['name'] = '%s_%s_segmodel' % (config['dataset'], config['arch'])
    os.makedirs(os.path.join(model_folder, '%s' % config['name']),
                exist_ok=True)

    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    with open(os.path.join(model_folder, '%s/config.yml' % config['name']),
              'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    # create model

    if 'False' in config['resume']:
        config['resume'] = False
    else:
        config['resume'] = True

    # Data loading code
    img_ids = glob(
        os.path.join(input_folder, config['dataset'], 'images', 'training',
                     '*' + config['img_ext']))
    train_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    img_ids = glob(
        os.path.join(input_folder, config['val_dataset'], 'images',
                     'validation', '*' + config['img_ext']))
    val_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    img_ids = glob(
        os.path.join(input_folder, config['val_dataset'], 'images', 'test',
                     '*' + config['img_ext']))
    test_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_transform = Compose([
        # transforms.RandomScale ([config['scale_min'], config['scale_max']]),
        # transforms.RandomRotate90(),
        transforms.Rotate([config['rotate_min'], config['rotate_max']],
                          value=mean,
                          mask_value=0),
        # transforms.GaussianBlur (),
        transforms.Flip(),
        # transforms.HorizontalFlip (),
        transforms.HueSaturationValue(hue_shift_limit=10,
                                      sat_shift_limit=10,
                                      val_shift_limit=10),
        transforms.RandomBrightnessContrast(brightness_limit=0.10,
                                            contrast_limit=0.10,
                                            brightness_by_max=True),
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(mean=mean, std=std),
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(mean=mean, std=std),
    ])

    train_dataset = Dataset(img_ids=train_img_ids,
                            img_dir=os.path.join(input_folder,
                                                 config['dataset'], 'images',
                                                 'training'),
                            mask_dir=os.path.join(input_folder,
                                                  config['dataset'],
                                                  'annotations', 'training'),
                            img_ext=config['img_ext'],
                            mask_ext=config['mask_ext'],
                            num_classes=config['num_classes'],
                            input_channels=config['input_channels'],
                            transform=train_transform)
    val_dataset = Dataset(img_ids=val_img_ids,
                          img_dir=os.path.join(input_folder, config['dataset'],
                                               'images', 'validation'),
                          mask_dir=os.path.join(input_folder,
                                                config['dataset'],
                                                'annotations', 'validation'),
                          img_ext=config['img_ext'],
                          mask_ext=config['mask_ext'],
                          num_classes=config['num_classes'],
                          input_channels=config['input_channels'],
                          transform=val_transform)
    test_dataset = Dataset(img_ids=test_img_ids,
                           img_dir=os.path.join(input_folder,
                                                config['dataset'], 'images',
                                                'test'),
                           mask_dir=os.path.join(input_folder,
                                                 config['dataset'],
                                                 'annotations', 'test'),
                           img_ext=config['img_ext'],
                           mask_ext=config['mask_ext'],
                           num_classes=config['num_classes'],
                           input_channels=config['input_channels'],
                           transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,  # config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,  # config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('dice', []),
        ('val_loss', []),
        ('val_iou', []),
        ('val_dice', []),
    ])
    if not os.path.isdir(checkpoint_folder):
        os.mkdir(checkpoint_folder)
    # create generator model
    #val_config = config_dict['config']
    generator_name = config['generator_name']
    with open(os.path.join(model_folder, '%s/config.yml' % generator_name),
              'r') as f:
        g_config = yaml.load(f, Loader=yaml.FullLoader)
    generator = Generator(g_config)
    generator.initialize_with_srresnet(model_folder, g_config)
    lr = config['gan_lr']
    # Initialize generator's optimizer
    optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                 generator.parameters()),
                                   lr=lr)
    #params = filter(lambda p: p.requires_grad, generator.parameters())
    #optimizer_g, scheduler_g = optimizer_scheduler(params, config)

    # Discriminator
    # Discriminator parameters
    num_classes = config['num_classes']
    kernel_size_d = 3  # kernel size in all convolutional blocks
    n_channels_d = 64  # number of output channels in the first convolutional block, after which it is doubled in every 2nd block thereafter
    n_blocks_d = 8  # number of convolutional blocks
    fc_size_d = 1024  # size of the first fully connected layer
    discriminator = Discriminator(num_classes,
                                  kernel_size=kernel_size_d,
                                  n_channels=n_channels_d,
                                  n_blocks=n_blocks_d,
                                  fc_size=fc_size_d)
    # Initialize discriminator's optimizer
    optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                 discriminator.parameters()),
                                   lr=lr)
    #params = filter(lambda p: p.requires_grad, discriminator.parameters())
    #optimizer_d, scheduler_d = optimizer_scheduler(params, config)
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()
    content_loss_criterion = nn.MSELoss()

    generator = generator.cuda()
    discriminator = discriminator.cuda()
    #truncated_vgg19 = truncated_vgg19.cuda()
    content_loss_criterion = content_loss_criterion.cuda()
    adversarial_loss_criterion = adversarial_loss_criterion.cuda()

    generator = torch.nn.DataParallel(generator)
    discriminator = torch.nn.DataParallel(discriminator)

    if not os.path.isdir(checkpoint_folder):
        os.mkdir(checkpoint_folder)
    log_name = config['name']
    log_dir = os.path.join(checkpoint_folder, log_name)
    writer = SummaryWriter(logdir=log_dir)

    best_iou = 0
    trigger = 0
    Best_dice = 0
    iou_AtBestDice = 0
    start_epoch = 0
    for epoch in range(start_epoch, config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(epoch, config, train_loader, generator,
                          discriminator, criterion, adversarial_loss_criterion,
                          content_loss_criterion, optimizer_g, optimizer_d)

        # evaluate on validation set
        val_log = validate(config, val_loader, generator, criterion)
        test_log = validate(config, test_loader, generator, criterion)

        if Best_dice < test_log['dice']:
            Best_dice = test_log['dice']
            iou_AtBestDice = test_log['iou']
        print(
            'loss %.4f - iou %.4f - dice %.4f - val_loss %.4f - val_iou %.4f - val_dice %.4f - test_iou %.4f - test_dice %.4f - Best_dice %.4f - iou_AtBestDice %.4f'
            % (train_log['loss'], train_log['iou'], train_log['dice'],
               val_log['loss'], val_log['iou'], val_log['dice'],
               test_log['iou'], test_log['dice'], Best_dice, iou_AtBestDice))

        save_tensorboard(writer, train_log, val_log, test_log, epoch)
        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['dice'].append(train_log['dice'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        log['val_dice'].append(val_log['dice'])

        pd.DataFrame(log).to_csv(os.path.join(model_folder,
                                              '%s/log.csv' % config['name']),
                                 index=False)
        trigger += 1

        if test_log['iou'] > best_iou:
            torch.save(
                generator.state_dict(),
                os.path.join(model_folder, '%s/model.pth' % config['name']))
            best_iou = test_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config[
                'early_stopping']:
            print("=> early stopping")
            break

        torch.cuda.empty_cache()
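
The listing ends without an entry point; a conventional guard would be:

if __name__ == '__main__':
    main()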
Example #14
# axs[0].imshow(image)
# axs[0].set_title("Original",fontsize=30)

# image_ = transforms.RandomRotate90(always_apply=True)(image=image)["image"]
# axs[1].imshow(image_)
# axs[1].set_title("Rotate",fontsize=30)

# image_ = transforms.Flip(always_apply=True)(image=image)["image"]
# axs[2].imshow(image_)
# axs[2].set_title("Flip",fontsize=30)

image_ = transforms.HueSaturationValue(always_apply=True)(image=image)["image"]
axs[0].imshow(image_)
axs[0].set_title("Saturation", fontsize=30)

image_ = transforms.RandomBrightnessContrast(always_apply=True)(
    image=image)["image"]
axs[1].imshow(image_)
axs[1].set_title("Brightness", fontsize=30)

image_ = OneOf(
    [
        transforms.RandomCrop(204, 250, p=1),
        transforms.CenterCrop(204, 250, p=1),
    ],
    p=1,
)(image=image)["image"]
axs[2].imshow(image_)
axs[2].set_title("Random Crop", fontsize=30)
Example #15
def augment(im, params=None):
    """
    Perform data augmentation on some image using the albumentations package.

    Parameters
    ----------
    im : Numpy array
    params : dict or None
        Contains the data augmentation parameters
        Mandatory keys:
        - h_flip ([0,1] float): probability of performing a horizontal left-right mirroring.
        - v_flip ([0,1] float): probability of performing a vertical up-down mirroring.
        - rot ([0,1] float):  probability of performing a rotation to the image.
        - rot_lim (int):  max degrees of rotation.
        - stretch ([0,1] float):  probability of randomly stretching an image.
        - crop ([0,1] float): randomly take an image crop.
        - zoom ([0,1] float): random zoom applied to crop_size.
            --> Therefore the effective crop size at each iteration will be a
                random number between 1 and crop*(1-zoom). For example:
                  * crop=1, zoom=0: no crop of the image
                  * crop=1, zoom=0.1: random crop of random size between 100% image and 90% of the image
                  * crop=0.9, zoom=0.1: random crop of random size between 90% image and 80% of the image
                  * crop=0.9, zoom=0: random crop of always 90% of the image
                  Image size refers to the size of the shortest side.
        - blur ([0,1] float):  probability of randomly blurring an image.
        - pixel_noise ([0,1] float):  probability of randomly adding pixel noise to an image.
        - pixel_sat ([0,1] float):  probability of randomly using HueSaturationValue in the image.
        - cutout ([0,1] float):  probability of using cutout in the image.

    Returns
    -------
    Numpy array
    """

    ## 1) Crop the image
    effective_zoom = np.random.rand() * params['zoom']
    crop = params['crop'] - effective_zoom

    ly, lx, channels = im.shape
    crop_size = int(crop * min([ly, lx]))
    rand_x = np.random.randint(low=0, high=lx - crop_size + 1)
    rand_y = np.random.randint(low=0, high=ly - crop_size + 1)

    crop = transforms.Crop(x_min=rand_x,
                           y_min=rand_y,
                           x_max=rand_x + crop_size,
                           y_max=rand_y + crop_size)

    im = crop(image=im)['image']

    ## 2) Now add the transformations for augmenting the image pixels
    transform_list = []

    # Add random stretching
    if params['stretch']:
        transform_list.append(
            imgaug_transforms.IAAPerspective(scale=0.1, p=params['stretch']))

    # Add random rotation
    if params['rot']:
        transform_list.append(
            transforms.Rotate(limit=params['rot_lim'], p=params['rot']))

    # Add horizontal flip
    if params['h_flip']:
        transform_list.append(transforms.HorizontalFlip(p=params['h_flip']))

    # Add vertical flip
    if params['v_flip']:
        transform_list.append(transforms.VerticalFlip(p=params['v_flip']))

    # Add some blur to the image
    if params['blur']:
        transform_list.append(
            albumentations.OneOf([
                transforms.MotionBlur(blur_limit=7, p=1.),
                transforms.MedianBlur(blur_limit=7, p=1.),
                transforms.Blur(blur_limit=7, p=1.),
            ],
                                 p=params['blur']))

    # Add pixel noise
    if params['pixel_noise']:
        transform_list.append(
            albumentations.OneOf(
                [
                    transforms.CLAHE(clip_limit=2, p=1.),
                    imgaug_transforms.IAASharpen(p=1.),
                    imgaug_transforms.IAAEmboss(p=1.),
                    transforms.RandomBrightnessContrast(contrast_limit=0,
                                                        p=1.),
                    transforms.RandomBrightnessContrast(brightness_limit=0,
                                                        p=1.),
                    transforms.RGBShift(p=1.),
                    transforms.RandomGamma(p=1.)  #,
                    # transforms.JpegCompression(),
                    # transforms.ChannelShuffle(),
                    # transforms.ToGray()
                ],
                p=params['pixel_noise']))

    # Add pixel saturation
    if params['pixel_sat']:
        transform_list.append(
            transforms.HueSaturationValue(p=params['pixel_sat']))

    # Randomly remove some regions from the image
    if params['cutout']:
        ly, lx, channels = im.shape
        scale_low, scale_high = 0.05, 0.25  # min and max size of the squares wrt the full image
        scale = np.random.uniform(scale_low, scale_high)
        transform_list.append(
            transforms.Cutout(num_holes=8,
                              max_h_size=int(scale * ly),
                              max_w_size=int(scale * lx),
                              p=params['cutout']))

    # Compose all image transformations and augment the image
    augmentation_fn = albumentations.Compose(transform_list)
    im = augmentation_fn(image=im)['image']

    return im
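
A usage sketch with an explicit `params` dict covering the mandatory keys from the docstring (values are illustrative):

params = {
    'h_flip': 0.5, 'v_flip': 0.5, 'rot': 0.3, 'rot_lim': 30,
    'stretch': 0.2, 'crop': 0.9, 'zoom': 0.1,
    'blur': 0.2, 'pixel_noise': 0.2, 'pixel_sat': 0.2, 'cutout': 0.3,
}
im_aug = augment(im, params)   # im: HWC uint8 NumPy array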
Example #16
def augment(im, params=None):
    """
    Perform data augmentation on some image using the albumentations package.

    Parameters
    ----------
    im : Numpy array
    params : dict or None
        Contains the data augmentation parameters
        Mandatory keys:
        - h_flip ([0,1] float): probability of performing a horizontal left-right mirroring.
        - v_flip ([0,1] float): probability of performing a vertical up-down mirroring.
        - rot ([0,1] float):  probability of performing a rotation to the image.
        - rot_lim (int):  max degrees of rotation.
        - stretch ([0,1] float):  probability of randomly stretching an image.
        - expand ([True, False] bool): whether to pad the image to a square shape with background color canvas.
        - enhance (float): contrast enhancement factor passed to PIL's ImageEnhance.Contrast; falsy to skip.
        - crop ([0,1] float): randomly take an image crop.
        - invert_col ([0, 1] float): randomly invert the colors of the image. p=1 -> invert colors (VPR)
        - zoom ([0,1] float): random zoom applied to crop_size.
            --> Therefore the effective crop size at each iteration will be a
                random number between 1 and crop*(1-zoom). For example:
                  * crop=1, zoom=0: no crop of the image
                  * crop=1, zoom=0.1: random crop of random size between 100% image and 90% of the image
                  * crop=0.9, zoom=0.1: random crop of random size between 90% image and 80% of the image
                  * crop=0.9, zoom=0: random crop of always 90% of the image
                  Image size refers to the size of the shortest side.
        - blur ([0,1] float):  probability of randomly blurring an image.
        - pixel_noise ([0,1] float):  probability of randomly adding pixel noise to an image.
        - pixel_sat ([0,1] float):  probability of randomly using HueSaturationValue in the image.
        - cutout ([0,1] float):  probability of using cutout in the image.

    Returns
    -------
    Numpy array
    """
    ## 1) Expand the image by padding it with bg-color canvas
    if params["expand"]:
        desired_size = max(im.shape)
        # check bg
        if np.argmax(im.shape) > 0:
            bgcol = tuple(np.repeat(int(np.mean(im[[0, -1], :, :])), 3))
        else:
            bgcol = tuple(np.repeat(int(np.mean(im[:, [0, -1], :])), 3))

        im = Image.fromarray(im)
        old_size = im.size  # old_size[0] is in (width, height) format

        ratio = float(desired_size) / max(old_size)
        new_size = tuple([int(x * ratio) for x in old_size])
        im = im.resize(new_size, Image.ANTIALIAS)  # ANTIALIAS was removed in Pillow 10; use Image.LANCZOS there
        # create a new image and paste the resized on it
        new_im = Image.new("RGB", (desired_size, desired_size), color=bgcol)
        new_im.paste(im, ((desired_size - new_size[0]) // 2,
                          (desired_size - new_size[1]) // 2))

        im = np.array(new_im)

    ## 2) Crop the image
    if params["crop"] and params["crop"] != 1:
        effective_zoom = np.random.rand() * params['zoom']
        crop = params['crop'] - effective_zoom

        ly, lx, channels = im.shape
        crop_size = int(crop * min([ly, lx]))
        rand_x = np.random.randint(low=0, high=lx - crop_size + 1)
        rand_y = np.random.randint(low=0, high=ly - crop_size + 1)

        crop = transforms.Crop(x_min=rand_x,
                               y_min=rand_y,
                               x_max=rand_x + crop_size,
                               y_max=rand_y + crop_size)

        im = crop(image=im)['image']

    if params["enhance"]:
        im = Image.fromarray(im)
        enhancer = ImageEnhance.Contrast(im)
        im = np.array(enhancer.enhance(params["enhance"]))

    ## 3) Now add the transformations for augmenting the image pixels
    transform_list = []

    if params['invert_col']:
        transform_list.append(transforms.InvertImg(p=params['invert_col']))

    # Add random stretching
    if params['stretch']:
        transform_list.append(
            imgaug_transforms.IAAPerspective(scale=0.1, p=params['stretch']))

    # Add random rotation
    if params['rot']:
        transform_list.append(
            transforms.Rotate(limit=params['rot_lim'], p=params['rot']))

    # Add horizontal flip
    if params['h_flip']:
        transform_list.append(transforms.HorizontalFlip(p=params['h_flip']))

    # Add vertical flip
    if params['v_flip']:
        transform_list.append(transforms.VerticalFlip(p=params['v_flip']))

    # Add some blur to the image
    if params['blur']:
        transform_list.append(
            albumentations.OneOf([
                transforms.MotionBlur(blur_limit=7, p=1.),
                transforms.MedianBlur(blur_limit=7, p=1.),
                transforms.Blur(blur_limit=7, p=1.),
            ],
                                 p=params['blur']))

    # Add pixel noise
    if params['pixel_noise']:
        transform_list.append(
            albumentations.OneOf(
                [
                    transforms.CLAHE(clip_limit=2, p=1.),
                    imgaug_transforms.IAASharpen(p=1.),
                    imgaug_transforms.IAAEmboss(p=1.),
                    transforms.RandomBrightnessContrast(contrast_limit=0,
                                                        p=1.),
                    transforms.RandomBrightnessContrast(brightness_limit=0,
                                                        p=1.),
                    transforms.RGBShift(p=1.),
                    transforms.RandomGamma(p=1.)  #,
                    # transforms.JpegCompression(),
                    # transforms.ChannelShuffle(),
                    # transforms.ToGray()
                ],
                p=params['pixel_noise']))

    # Add pixel saturation
    if params['pixel_sat']:
        transform_list.append(
            transforms.HueSaturationValue(p=params['pixel_sat']))

    # Randomly remove some regions from the image
    if params['cutout']:
        ly, lx, channels = im.shape
        scale_low, scale_high = 0.05, 0.25  # min and max size of the squares wrt the full image
        scale = np.random.uniform(scale_low, scale_high)
        transform_list.append(
            transforms.Cutout(num_holes=8,
                              max_h_size=int(scale * ly),
                              max_w_size=int(scale * lx),
                              p=params['cutout']))

    # Compose all image transformations and augment the image
    augmentation_fn = albumentations.Compose(transform_list)
    im = augmentation_fn(image=im)['image']

    return im
Example #17
# assumed aliases: alb = albumentations, aat = albumentations.augmentations.transforms
alb_transforms = [
    alb.IAAAdditiveGaussianNoise(p=1),
    alb.GaussNoise(p=1),
    alb.MotionBlur(p=1),
    alb.MedianBlur(blur_limit=3, p=1),
    alb.Blur(blur_limit=3, p=1),
    alb.OpticalDistortion(p=1),
    alb.GridDistortion(p=1),
    alb.IAAPiecewiseAffine(p=1),
    aat.CLAHE(clip_limit=2, p=1),
    alb.IAASharpen(p=1),
    alb.IAAEmboss(p=1),
    aat.HueSaturationValue(p=0.3),
    aat.HorizontalFlip(p=1),
    aat.RGBShift(),
    aat.RandomBrightnessContrast(),
    aat.RandomGamma(p=1),
    aat.Cutout(2, 10, 10, p=1),
    aat.Equalize(mode='cv', p=1),
    aat.FancyPCA(p=1),
    aat.RandomFog(p=1),
    aat.RandomRain(blur_value=3, p=1),
    albumentations.IAAAffine(p=1),
    albumentations.ShiftScaleRotate(rotate_limit=15, p=1)
]
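
# A sketch of exercising each transform in isolation, which is presumably what
# one_by_one() below does; most entries are built with p=1 so they fire on
# every call (an image array `img` is assumed).
def preview_all(img):
    return [t(image=img)['image'] for t in alb_transforms]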


def one_by_one():
    data_dir = 'train_val/pic'
    data_anno = pd.read_csv('train_val/keys.csv')
    orig = MyDataset(data_dir, data_anno)
Example #18
    def __init__(
        self,
        prob=0,
        Flip_prob=0,
        HueSaturationValue_prob=0,
        RandomBrightnessContrast_prob=0,
        crop_prob=0,
        randomrotate90_prob=0,
        elastictransform_prob=0,
        gridistortion_prob=0,
        opticaldistortion_prob=0,
        verticalflip_prob=0,
        horizontalflip_prob=0,
        randomgamma_prob=0,
        CoarseDropout_prob=0,
        RGBShift_prob=0,
        MotionBlur_prob=0,
        MedianBlur_prob=0,
        GaussianBlur_prob=0,
        GaussNoise_prob=0,
        ChannelShuffle_prob=0,
        ColorJitter_prob=0,
    ):
        super().__init__()

        self.prob = prob
        self.randomrotate90_prob = randomrotate90_prob
        self.elastictransform_prob = elastictransform_prob

        self.transforms = al.Compose(
            [
                transforms.RandomRotate90(p=randomrotate90_prob),
                transforms.Flip(p=Flip_prob),
                transforms.HueSaturationValue(p=HueSaturationValue_prob),
                transforms.RandomBrightnessContrast(
                    p=RandomBrightnessContrast_prob),
                transforms.Transpose(),
                OneOf(
                    [
                        transforms.RandomCrop(220, 220, p=0.5),
                        transforms.CenterCrop(220, 220, p=0.5),
                    ],
                    p=crop_prob,
                ),
                ElasticTransform(
                    p=elastictransform_prob,
                    alpha=120,
                    sigma=120 * 0.05,
                    alpha_affine=120 * 0.03,
                ),
                GridDistortion(p=gridistortion_prob),
                OpticalDistortion(p=opticaldistortion_prob,
                                  distort_limit=2,
                                  shift_limit=0.5),
                VerticalFlip(p=verticalflip_prob),
                HorizontalFlip(p=horizontalflip_prob),
                RandomGamma(p=randomgamma_prob),
                RGBShift(p=RGBShift_prob),
                MotionBlur(p=MotionBlur_prob, blur_limit=7),
                MedianBlur(p=MedianBlur_prob, blur_limit=9),
                GaussianBlur(p=GaussianBlur_prob, blur_limit=9),
                GaussNoise(p=GaussNoise_prob),
                ChannelShuffle(p=ChannelShuffle_prob),
                CoarseDropout(p=CoarseDropout_prob,
                              max_holes=8,
                              max_height=32,
                              max_width=32),
                ColorJitter(p=ColorJitter_prob)
                # transforms.Resize(352, 352),
                # transforms.Normalize(
                #     mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                # ),
            ],
            p=self.prob,
        )
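
As in Example #2, a usage sketch with a hypothetical wrapper name `SegAug`:

aug = SegAug(prob=1.0, Flip_prob=0.5, elastictransform_prob=0.3)
out = aug.transforms(image=img, mask=mask)   # img, mask: NumPy arrays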
Example #19
    def get_transforms(phase: str, cli_args) -> Dict[str, Compose]:
        """Get composed albumentations augmentations

        Parameters
        ----------
        phase : str
            Phase of learning
            In ['train', 'val']
        cli_args
            Arguments coming all the way from `main.py`

        Returns
        -------
        transforms: dict[str, albumentations.core.composition.Compose]
            Composed list of transforms
        """
        aug_transforms = []
        im_sz = (cli_args.image_size, cli_args.image_size)

        if phase == "train":
            # Data augmentation for training only
            aug_transforms.extend([
                tf.ShiftScaleRotate(
                    shift_limit=0,
                    scale_limit=0.1,
                    rotate_limit=15,
                    p=0.5),
                tf.Flip(p=0.5),
                tf.RandomRotate90(p=0.5),
            ])
            # Exotic Augmentations for train only 🤤
            aug_transforms.extend([
                tf.RandomBrightnessContrast(p=0.5),
                tf.ElasticTransform(p=0.5),
                tf.MultiplicativeNoise(multiplier=(0.5, 1.5),
                                       per_channel=True, p=0.2),
            ])
        aug_transforms.extend([
            tf.RandomSizedCrop(min_max_height=im_sz,
                               height=im_sz[0],
                               width=im_sz[1],
                               w2h_ratio=1.0,
                               interpolation=cv2.INTER_LINEAR,
                               p=1.0),
        ])
        aug_transforms = Compose(aug_transforms)

        mask_only_transforms = Compose([
            tf.Normalize(mean=0, std=1, always_apply=True)
        ])
        image_only_transforms = Compose([
            tf.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0),
                         always_apply=True)
        ])
        final_transforms = Compose([
            ToTensorV2()
        ])

        transforms = {
            'aug': aug_transforms,
            'img_only': image_only_transforms,
            'mask_only': mask_only_transforms,
            'final': final_transforms
        }
        return transforms
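
A sketch of chaining the returned stages (`cli_args` and input arrays assumed; each stage is a plain `Compose`, so the mask-only normalization is applied through the `image` keyword):

tfs = get_transforms('train', cli_args)
data = tfs['aug'](image=img, mask=mask)
img_n = tfs['img_only'](image=data['image'])['image']
mask_n = tfs['mask_only'](image=data['mask'])['image']
batch = tfs['final'](image=img_n, mask=mask_n)   # ToTensorV2 on both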
Example #20
    def __init__(self) -> None:
        """
        Initializes the random brightener/darkener
        """
        self.darken_or_brighten_object = albu.RandomBrightnessContrast(
            (-1.0, 1.0), (0.0, 0.0), always_apply=True)
Example #21
def main():
    args = vars(parse_args_func())

    #config_file = "../configs/config_SN7.json"
    config_file = args['config']  # "../configs/config_v1.json"
    config_dict = json.loads(open(config_file, 'rt').read())
    #config_dict = json.loads(open(sys.argv[1], 'rt').read())

    file_dict = config_dict['file_path']
    config = config_dict['opt_config']

    input_folder = file_dict['input_path']  # '../inputs'
    checkpoint_folder = file_dict['checkpoint_path']  # '../checkpoint'
    model_folder = file_dict['model_path']  # '../models'

    if 'False' in config['deep_supervision']:
        config['deep_supervision'] = False
    else:
        config['deep_supervision'] = True

    if 'False' in config['nesterov']:
        config['nesterov'] = False
    else:
        config['nesterov'] = True

    if 'None' in config['name']:
        config['name'] = None

    if config['name'] is None:
        config['name'] = '%s_%s_segmodel' % (config['dataset'], config['arch'])
    os.makedirs(os.path.join(model_folder, '%s' % config['name']),
                exist_ok=True)

    if not os.path.isdir(checkpoint_folder):
        os.mkdir(checkpoint_folder)
    log_name = config['name']
    log_dir = os.path.join(checkpoint_folder, log_name)
    writer = SummaryWriter(logdir=log_dir)

    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    with open(os.path.join(model_folder, '%s/config.yml' % config['name']),
              'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    if 'False' in config['resume']:
        config['resume'] = False
    else:
        config['resume'] = True
    resume_flag = config['resume']
    if resume_flag:
        save_path = os.path.join(model_folder, config['name'], 'model.pth')
        weights = torch.load(save_path)
        model.load_state_dict(weights)
        name_yaml = config['name']
        with open(os.path.join(model_folder, '%s/config.yml' % name_yaml),
                  'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        #start_epoch = config['epochs']
        start_epoch = 0
    else:
        start_epoch = 0

    model = model.cuda()
    if 'effnet' in config['arch']:
        eff_flag = True
    else:
        eff_flag = False

    if eff_flag:
        cnn_subs = list(model.encoder.eff_conv.children())[1:]
        #cnn_params = [list(sub_module.parameters()) for sub_module in cnn_subs]
        #cnn_params = [item for sublist in cnn_params for item in sublist]

    summary(model,
            (config['input_channels'], config['input_w'], config['input_h']))
    params = filter(lambda p: p.requires_grad, model.parameters())
    if eff_flag:
        params = list(params) + list(model.encoder.conv_a.parameters())
    model = torch.nn.DataParallel(model)

    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(params,
                               lr=config['lr'],
                               weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params,
                              lr=config['lr'],
                              momentum=config['momentum'],
                              nesterov=config['nesterov'],
                              weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if eff_flag:
        cnn_params = [list(sub_module.parameters()) for sub_module in cnn_subs]
        cnn_params = [item for sublist in cnn_params for item in sublist]
        cnn_optimizer = torch.optim.Adam(cnn_params,
                                         lr=0.001,
                                         weight_decay=config['weight_decay'])
        #cnn_optimizer = None

    else:
        cnn_optimizer = None
    if config['optimizer'] == 'SGD':
        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=config['factor'],
                patience=config['patience'],
                verbose=1,
                min_lr=config['min_lr'])
        elif config['scheduler'] == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(e) for e in config['milestones'].split(',')],
                gamma=config['gamma'])
        elif config['scheduler'] == 'ConstantLR':
            scheduler = None
        else:
            raise NotImplementedError
    else:
        scheduler = None

    # Data loading code
    img_ids = glob(
        os.path.join(input_folder, config['dataset'], 'images', 'training',
                     '*' + config['img_ext']))
    train_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    #img_dir = os.path.join(input_folder, config['dataset'], 'images', 'training')
    #mask_dir = os.path.join(input_folder, config['dataset'], 'annotations', 'training')
    #train_image_mask = image_to_afile(img_dir, mask_dir, None, train_img_ids, config)

    img_ids = glob(
        os.path.join(input_folder, config['val_dataset'], 'images',
                     'validation', '*' + config['img_ext']))
    val_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    img_ids = glob(
        os.path.join(input_folder, config['val_dataset'], 'images', 'test',
                     '*' + config['img_ext']))
    test_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_transform = Compose([
        #transforms.RandomScale ([config['scale_min'], config['scale_max']]),
        #transforms.RandomRotate90(),
        transforms.Rotate([config['rotate_min'], config['rotate_max']],
                          value=mean,
                          mask_value=0),
        transforms.Flip(),
        #transforms.HorizontalFlip (),
        transforms.HueSaturationValue(hue_shift_limit=10,
                                      sat_shift_limit=10,
                                      val_shift_limit=10),
        transforms.RandomBrightnessContrast(brightness_limit=0.10,
                                            contrast_limit=0.10,
                                            brightness_by_max=True),
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(mean=mean, std=std),
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(mean=mean, std=std),
    ])

    train_dataset = Dataset(img_ids=train_img_ids,
                            img_dir=os.path.join(input_folder,
                                                 config['dataset'], 'images',
                                                 'training'),
                            mask_dir=os.path.join(input_folder,
                                                  config['dataset'],
                                                  'annotations', 'training'),
                            img_ext=config['img_ext'],
                            mask_ext=config['mask_ext'],
                            num_classes=config['num_classes'],
                            input_channels=config['input_channels'],
                            transform=train_transform,
                            from_file=None)
    val_dataset = Dataset(img_ids=val_img_ids,
                          img_dir=os.path.join(input_folder,
                                               config['val_dataset'], 'images',
                                               'validation'),
                          mask_dir=os.path.join(input_folder,
                                                config['val_dataset'],
                                                'annotations', 'validation'),
                          img_ext=config['img_ext'],
                          mask_ext=config['mask_ext'],
                          num_classes=config['num_classes'],
                          input_channels=config['input_channels'],
                          transform=val_transform,
                          from_file=None)
    test_dataset = Dataset(img_ids=test_img_ids,
                           img_dir=os.path.join(input_folder,
                                                config['val_dataset'],
                                                'images', 'test'),
                           mask_dir=os.path.join(input_folder,
                                                 config['val_dataset'],
                                                 'annotations', 'test'),
                           img_ext=config['img_ext'],
                           mask_ext=config['mask_ext'],
                           num_classes=config['num_classes'],
                           input_channels=config['input_channels'],
                           transform=val_transform,
                           from_file=None)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,  #config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,  #config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('dice', []),
        ('val_loss', []),
        ('val_iou', []),
        ('val_dice', []),
    ])

    best_iou = 0
    trigger = 0
    Best_dice = 0
    iou_AtBestDice = 0
    for epoch in range(start_epoch, config['epochs']):
        print('{:s} Epoch [{:d}/{:d}]'.format(config['arch'], epoch,
                                              config['epochs']))
        # train for one epoch
        train_log = train(epoch, config, train_loader, model, criterion,
                          optimizer, cnn_optimizer)
        # evaluate on validation set first: ReduceLROnPlateau steps on the
        # current validation loss
        val_log = validate(config, val_loader, model, criterion)
        test_log = validate(config, test_loader, model, criterion)

        if config['optimizer'] == 'SGD':
            if config['scheduler'] == 'CosineAnnealingLR':
                scheduler.step()
            elif config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_log['loss'])
            elif config['scheduler'] == 'MultiStepLR':
                scheduler.step()

        if Best_dice < test_log['dice']:
            Best_dice = test_log['dice']
            iou_AtBestDice = test_log['iou']
        print(
            'loss %.4f - iou %.4f - dice %.4f - val_loss %.4f - val_iou %.4f - val_dice %.4f - test_iou %.4f - test_dice %.4f - Best_dice %.4f - iou_AtBestDice %.4f'
            % (train_log['loss'], train_log['iou'], train_log['dice'],
               val_log['loss'], val_log['iou'], val_log['dice'],
               test_log['iou'], test_log['dice'], Best_dice, iou_AtBestDice))

        save_tensorboard(writer, train_log, val_log, test_log, epoch)
        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['dice'].append(train_log['dice'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        log['val_dice'].append(val_log['dice'])

        pd.DataFrame(log).to_csv(os.path.join(model_folder,
                                              '%s/log.csv' % config['name']),
                                 index=False)

        trigger += 1

        if val_log['iou'] > best_iou:
            torch.save(
                model.state_dict(),
                os.path.join(model_folder, '%s/model.pth' % config['name']))
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config[
                'early_stopping']:
            print("=> early stopping")
            break

        torch.cuda.empty_cache()