Code Example #1
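All snippets on this page share the same module-level context: OpenCV (`cv2`), a `Compose` pipeline from `transforms`, the pytvision transform module imported as `mtrans`, and a module-level `normalize` transform. A minimal preamble sketch, assuming torchvision's `Compose` and an ImageNet-style mean/std normalization; the concrete `normalize` definition is project-specific and not shown on this page:

# Sketch of the shared preamble assumed by the examples below.
import cv2
from torchvision import transforms                      # assumption: torchvision's Compose
from pytvision.transforms import transforms as mtrans   # as in Code Examples #14 and #15

# Assumption: ImageNet-style normalization; some projects use mtrans.ToNormalization() instead.
normalize = mtrans.ToMeanNormalization(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)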
def get_transforms_fer_aug( size_input ):
    return transforms.Compose([

        #------------------------------------------------------------------
        #Resize input
        mtrans.ToResize( (48, 48), resize_mode='square', padding_mode=cv2.BORDER_REFLECT ),

        #------------------------------------------------------------------
        #Geometric
        mtrans.RandomScale(factor=0.2, padding_mode=cv2.BORDER_REPLICATE ),
        mtrans.ToRandomTransform( mtrans.RandomGeometricalTransform( angle=30, translation=0.2, warp=0.02, padding_mode=cv2.BORDER_REFLECT ), prob=0.5 ),
        mtrans.ToRandomTransform( mtrans.VFlip(), prob=0.5 ),
        #mtrans.ToRandomTransform( mtrans.HFlip(), prob=0.5 ),

        #------------------------------------------------------------------
        #Colors
        mtrans.ToRandomTransform( mtrans.RandomBrightness( factor=0.25 ), prob=0.50 ),
        mtrans.ToRandomTransform( mtrans.RandomContrast( factor=0.25 ), prob=0.50 ),
        mtrans.ToRandomTransform( mtrans.RandomGamma( factor=0.25 ), prob=0.50 ),
        mtrans.ToRandomTransform( mtrans.RandomRGBPermutation(), prob=0.50 ),
        mtrans.ToRandomTransform( mtrans.CLAHE(), prob=0.25 ),
        mtrans.ToRandomTransform( mtrans.ToGaussianBlur( sigma=0.05 ), prob=0.25 ),

        #------------------------------------------------------------------
        #Resize
        mtrans.ToResize( (size_input+5, size_input+5), resize_mode='square', padding_mode=cv2.BORDER_REFLECT ),
        mtrans.RandomCrop( (size_input, size_input), limit=2, padding_mode=cv2.BORDER_REFLECT  ),

        #------------------------------------------------------------------
        mtrans.ToGrayscale(),
        mtrans.ToTensor(),
        normalize,

        ])
Code Example #2
def get_transforms_aug2(size_input):
    transforms_aug = transforms.Compose([
        mtrans.ToResize((size_input, size_input),
                        resize_mode='square',
                        padding_mode=cv2.BORDER_REFLECT_101),
        #mtrans.ToResize( (size_input+20, size_input+20), resize_mode='asp' ) ,
        #mtrans.RandomCrop( (size_input, size_input), limit=0, padding_mode=cv2.BORDER_REFLECT_101  ) ,
        mtrans.RandomScale(factor=0.2, padding_mode=cv2.BORDER_REFLECT_101),
        mtrans.RandomGeometricalTransform(angle=45,
                                          translation=0.2,
                                          warp=0.02,
                                          padding_mode=cv2.BORDER_REFLECT_101),
        mtrans.ToRandomTransform(mtrans.HFlip(), prob=0.5),
        #------------------------------------------------------------------
        #mtrans.RandomRGBPermutation(),
        #mtrans.ToRandomChoiceTransform( [
        #    mtrans.RandomBrightness( factor=0.15 ),
        #    mtrans.RandomContrast( factor=0.15 ),
        #    #mtrans.RandomSaturation( factor=0.15 ),
        #    mtrans.RandomHueSaturation( hue_shift_limit=(-5, 5), sat_shift_limit=(-11, 11), val_shift_limit=(-11, 11) ),
        #    mtrans.RandomGamma( factor=0.30  ),
        #    mtrans.ToRandomTransform(mtrans.ToGrayscale(), prob=0.15 ),
        #    ]),
        #mtrans.ToRandomTransform(mtrans.ToGaussianBlur( sigma=0.0001), prob=0.15 ),

        #------------------------------------------------------------------
        mtrans.ToTensor(),
        normalize,
    ])
    return transforms_aug
Code Example #3
def get_transforms_fer_det(size_input):
    return transforms.Compose([
        mtrans.ToResize( (size_input, size_input), resize_mode='square', padding_mode=cv2.BORDER_REFLECT ),
        mtrans.ToGrayscale(),
        mtrans.ToTensor(),
        normalize,
        ])
Code Example #4
File: aug.py  Project: yeyuhang/ferattention
def get_transforms_aug( size_input ):
    return transforms.Compose([        
        
        #------------------------------------------------------------------
        #Resize 
        mtrans.ToResize( (size_input,size_input), resize_mode='square', padding_mode=cv2.BORDER_REPLICATE),  
        
        #------------------------------------------------------------------
        #Colors           
        mtrans.ToRandomTransform( mtrans.RandomBrightness( factor=0.25 ), prob=0.750 ),
        mtrans.ToRandomTransform( mtrans.RandomContrast( factor=0.25 ), prob=0.750 ),
        mtrans.ToRandomTransform( mtrans.RandomGamma( factor=0.25 ), prob=0.750 ),
#         mtrans.ToRandomTransform( mtrans.RandomRGBPermutation(), prob=0.50 ),
#         mtrans.ToRandomTransform( mtrans.CLAHE(), prob=0.25 ),
        mtrans.ToRandomTransform(mtrans.ToGaussianBlur( sigma=0.05 ), prob=0.25 ),
#         mtrans.ToRandomTransform(mtrans.ToGaussianNoise( sigma=0.05 ), prob=0.25 ),
        
        
#         mtrans.ToRandomTransform( mtrans.RandomBrightness( factor=0.25 ), prob=0.50 ), 
#         mtrans.ToRandomTransform( mtrans.RandomContrast( factor=0.25 ), prob=0.50 ), 
#         mtrans.ToRandomTransform( mtrans.RandomGamma( factor=0.25 ), prob=0.50 ), 
#         mtrans.ToRandomTransform( mtrans.RandomRGBPermutation(), prob=0.50 ), 
#         mtrans.ToRandomTransform( mtrans.CLAHE(), prob=0.25 ), 
#         mtrans.ToRandomTransform( mtrans.ToGaussianBlur( sigma=0.05 ), prob=0.25 ), 
        
        
        #------------------------------------------------------------------
        mtrans.ToTensor(),
        normalize,
        
        ])    
Code Example #5
def get_transforms_det(size_input):    
    transforms_det = transforms.Compose([
        #mtrans.ToResize( (size_input, size_input), resize_mode='crop' ) ,
        mtrans.ToResize( (size_input, size_input), resize_mode='square', padding_mode=cv2.BORDER_REPLICATE ) ,
        mtrans.ToTensor(),
        normalize,
        ])
    return transforms_det
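The augmented and deterministic getters are typically paired: the augmented pipeline feeds the training loader and the deterministic one the validation loader, as the full training script in Code Example #13 below shows. A minimal wiring sketch, where `MyDataset` is a placeholder for any dataset class that accepts a `transform=` argument (it is not an API from this page):

from torch.utils.data import DataLoader

# Placeholder dataset; substitute a real dataset class such as TGSDataset.
train_data = MyDataset(path, 'train', transform=get_transforms_aug(size_input=128))
val_data = MyDataset(path, 'val', transform=get_transforms_det(size_input=128))

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=4)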
Code Example #6
def get_transforms_test(size_input=256):
    return transforms.Compose([
        #mtrans.RandomCrop( (size_crop, size_crop), limit=10, padding_mode=cv2.BORDER_REFLECT_101  ),
        mtrans.ToResize((size_input, size_input),
                        resize_mode='square',
                        padding_mode=cv2.BORDER_REFLECT_101),
        #mtrans.ToPad( 5 , 5, padding_mode=cv2.BORDER_REFLECT_101 ),
        mtrans.ToTensor(),
        normalize,
    ])
Code Example #7
def transform_aug():
    return transforms.Compose([

        #mtrans.HFlip(),
        #mtrans.VFlip(),
        #mtrans.Rotate90(),
        #mtrans.Rotate180(),
        #mtrans.Rotate270(),

        ## resize and crop
        mtrans.ToResize((600, 600),
                        resize_mode='asp',
                        padding_mode=cv2.BORDER_CONSTANT),
        #mtrans.ToPad( 20, 20, padding_mode=cv2.BORDER_CONSTANT ) ,
        #mtrans.CenterCrop( (200,200) ),
        #mtrans.RandomCrop( (400,400), limit=50, padding_mode=cv2.BORDER_REFLECT_101  ),
        #mtrans.ToResizeUNetFoV(388, cv2.BORDER_REFLECT_101),

        # ## color
        # mtrans.ToRandomChoiceTransform( [
        # mtrans.RandomSaturation(),
        # mtrans.RandomHueSaturationShift(),
        # mtrans.RandomHueSaturation(),
        # mtrans.RandomRGBShift(),
        # #mtrans.ToNegative(),
        # mtrans.RandomRGBPermutation(),
        # mtrans.ToRandomTransform( mtrans.ToGrayscale(), prob=0.5 ),
        # #mtrans.ToGrayscale(),
        # ]),

        ## blur
        # mtrans.ToRandomTransform( mtrans.ToLinealMotionBlur( lmax=1 ), prob=0.5 ),
        mtrans.ToRandomTransform(mtrans.ToMotionBlur(), prob=1.0),
        # mtrans.ToRandomTransform( mtrans.ToGaussianBlur(), prob=0.75 ),

        ## geometrical
        #mtrans.ToRandomTransform( mtrans.HFlip(), prob=0.5 ),
        #mtrans.ToRandomTransform( mtrans.VFlip(), prob=0.5 ),
        #mtrans.RandomScale(factor=0.2, padding_mode=cv2.BORDER_REFLECT101 ),
        #mtrans.RandomGeometricalTransform( angle=360, translation=0.2, warp=0.02, padding_mode=cv2.BORDER_REFLECT101),
        #mtrans.RandomElasticDistort( size_grid=50, padding_mode=cv2.BORDER_REFLECT101 ),

        ## tensor
        mtrans.ToTensor(),
        # mtrans.RandomElasticTensorDistort( size_grid=10, deform=0.05 ),

        ## normalization
        mtrans.ToNormalization(),
        #mtrans.ToWhiteNormalization(),
        #mtrans.ToMeanNormalization(
        #    mean=[0.485, 0.456, 0.406],
        #    std=[0.229, 0.224, 0.225]
        #    ),
    ])
Code Example #8
def get_transforms_det(size_input):
    return transforms.Compose([
        #mtrans.ToResize( (38, 38), resize_mode='squash', padding_mode=cv2.BORDER_REPLICATE ),
        #mtrans.ToPad( 5 , 5, padding_mode=cv2.BORDER_REPLICATE ) ,
        #mtrans.ToResize( (size_input+20, size_input+20), resize_mode='squash', padding_mode=cv2.BORDER_REPLICATE ),
        #mtrans.CenterCrop( (size_input, size_input), padding_mode=cv2.BORDER_REPLICATE  ) ,
        mtrans.ToResize( (size_input, size_input), resize_mode='squash' ) ,
        #mtrans.ToResize( (size_input, size_input), resize_mode='square', padding_mode=cv2.BORDER_REPLICATE ) ,
        mtrans.ToGrayscale(),
        mtrans.ToTensor(),
        normalize,
        ])
Code Example #9
def get_transforms_det(size_input):
    return transforms.Compose([
        #mtrans.ToResize( (38, 38), resize_mode='squash', padding_mode=cv2.BORDER_REPLICATE ),
        #mtrans.ToPad( 5 , 5, padding_mode=cv2.BORDER_REPLICATE ) ,
        #mtrans.ToResize( (size_input+20, size_input+20), resize_mode='squash', padding_mode=cv2.BORDER_REPLICATE ),
        #mtrans.CenterCrop( (size_input, size_input), padding_mode=cv2.BORDER_REPLICATE  ) ,
        #mtrans.ToResize( (128, 128), resize_mode='squash' ) ,

        #mtrans.RandomCrop( (size_input, size_input), limit=2, padding_mode=cv2.BORDER_REFLECT_101  ) ,
        mtrans.ToResize((size_input, size_input),
                        resize_mode='square',
                        padding_mode=cv2.BORDER_REFLECT_101),
        mtrans.ToTensor(),
        normalize,
    ])
Code Example #10
def get_transforms_aug(size_input=256, size_crop=512):
    return transforms.Compose([

        #------------------------------------------------------------------
        #Resize
        #
        mtrans.RandomCrop((size_crop, size_crop),
                          limit=10,
                          padding_mode=cv2.BORDER_REFLECT_101),
        mtrans.ToResize((size_input, size_input),
                        resize_mode='square',
                        padding_mode=cv2.BORDER_REFLECT_101),
        #mtrans.ToPad( 5 , 5, padding_mode=cv2.BORDER_REFLECT_101 ),

        #------------------------------------------------------------------
        #Geometric
        mtrans.ToRandomTransform(mtrans.VFlip(), prob=0.5),
        mtrans.ToRandomTransform(mtrans.HFlip(), prob=0.5),
        mtrans.RandomScale(factor=0.3, padding_mode=cv2.BORDER_REFLECT_101),
        mtrans.ToRandomTransform(mtrans.RandomGeometricalTransform(
            angle=45,
            translation=0.2,
            warp=0.02,
            padding_mode=cv2.BORDER_REFLECT_101),
                                 prob=0.5),
        mtrans.ToRandomTransform(mtrans.RandomElasticDistort(
            size_grid=32, deform=12, padding_mode=cv2.BORDER_REFLECT_101),
                                 prob=0.5),
        #mtrans.ToResizeUNetFoV(imsize, cv2.BORDER_REFLECT_101),

        #------------------------------------------------------------------
        #Colors
        mtrans.ToRandomTransform(mtrans.RandomBrightness(factor=0.25),
                                 prob=0.50),
        mtrans.ToRandomTransform(mtrans.RandomContrast(factor=0.25),
                                 prob=0.50),
        mtrans.ToRandomTransform(mtrans.RandomGamma(factor=0.25), prob=0.50),
        #mtrans.ToRandomTransform( mtrans.RandomRGBPermutation(), prob=0.50 ),
        #mtrans.ToRandomTransform( mtrans.CLAHE(), prob=0.25 ),
        #mtrans.ToRandomTransform(mtrans.ToGaussianBlur( sigma=0.05 ), prob=0.25 ),

        #------------------------------------------------------------------
        mtrans.ToTensor(),
        normalize,
    ])
Code Example #11
def get_transforms_aug(size_input):
    return transforms.Compose([

        #------------------------------------------------------------------
        #Resize
        mtrans.ToResize((size_input, size_input),
                        resize_mode='square',
                        padding_mode=cv2.BORDER_REFLECT_101),
        #mtrans.ToResize( (size_input+10, size_input+10), resize_mode='square', padding_mode=cv2.BORDER_REFLECT_101 ) ,
        #mtrans.RandomCrop( (size_input, size_input), limit=2, padding_mode=cv2.BORDER_REFLECT_101  ) ,

        #------------------------------------------------------------------
        #Geometric
        mtrans.RandomScale(factor=0.2, padding_mode=cv2.BORDER_REFLECT_101),
        mtrans.ToRandomTransform(mtrans.RandomGeometricalTransform(
            angle=45,
            translation=0.2,
            warp=0.02,
            padding_mode=cv2.BORDER_REFLECT_101),
                                 prob=0.5),
        mtrans.ToRandomTransform(mtrans.VFlip(), prob=0.5),
        mtrans.ToRandomTransform(mtrans.HFlip(), prob=0.5),
        #mtrans.ToRandomTransform( mtrans.RandomElasticDistort( size_grid=32, deform=12, padding_mode=cv2.BORDER_REFLECT_101 ), prob=0.5 ),

        #------------------------------------------------------------------
        #Colors
        mtrans.ToRandomTransform(mtrans.RandomBrightness(factor=0.25),
                                 prob=0.50),
        mtrans.ToRandomTransform(mtrans.RandomContrast(factor=0.25),
                                 prob=0.50),
        mtrans.ToRandomTransform(mtrans.RandomGamma(factor=0.25), prob=0.50),
        #mtrans.ToRandomTransform( mtrans.RandomHueSaturation( hue_shift_limit=(-5, 5), sat_shift_limit=(-11, 11), val_shift_limit=(-11, 11) ), prob=0.30 ),
        #mtrans.ToRandomTransform( mtrans.RandomRGBPermutation(), prob=0.50 ),
        #mtrans.ToRandomTransform( mtrans.CLAHE(), prob=0.25 ),
        #mtrans.ToRandomTransform(mtrans.ToGaussianBlur( sigma=0.05 ), prob=0.25 ),

        #------------------------------------------------------------------
        mtrans.ToTensor(),
        normalize,
    ])
Code Example #12
    gpu=0
    seed=1
    imsize=101


    # Load dataset
    print('>> Load dataset ...')

    dataset  = TGSDataset(  
        pathnamedataset, 
        'test', 
        num_channels=3,
        train=False, 
        files='sample_submission.csv',
        transform=transforms.Compose([
            mtrans.ToResize( (256,256), resize_mode='squash', padding_mode=cv2.BORDER_REFLECT_101 ),
            #mtrans.ToResizeUNetFoV(imsize, cv2.BORDER_REFLECT_101),
            mtrans.ToTensor(),
            #mtrans.ToNormalization(), 
            mtrans.ToMeanNormalization( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], )
            ])
        )

    # load model
    print('>> Load model ...')

    net = SegmentationNeuralNet( 
        patchproject=project, 
        nameproject=projectname, 
        no_cuda=cuda, 
        parallel=parallel, 
Code Example #13
def main():

    # parameters
    parser = arg_parser()
    args = parser.parse_args()
    imsize = args.image_size
    parallel = args.parallel
    num_classes = 2
    num_channels = 3
    view_freq = 2

    network = SegmentationNeuralNet(
        patchproject=args.project,
        nameproject=args.name,
        no_cuda=args.no_cuda,
        parallel=parallel,
        seed=args.seed,
        print_freq=args.print_freq,
        gpu=args.gpu,
        view_freq=view_freq,
    )

    network.create(arch=args.arch,
                   num_output_channels=num_classes,
                   num_input_channels=num_channels,
                   loss=args.loss,
                   lr=args.lr,
                   momentum=args.momentum,
                   optimizer=args.opt,
                   lrsch=args.scheduler,
                   pretrained=args.finetuning,
                   size_input=imsize)

    # resume
    network.resume(os.path.join(network.pathmodels, args.resume))
    cudnn.benchmark = True

    # datasets
    # training dataset
    train_data = tgsdata.TGSDataset(
        args.data,
        tgsdata.train,
        count=16000,
        num_channels=num_channels,
        transform=transforms.Compose([
            mtrans.ToRandomTransform(mtrans.HFlip(), prob=0.5),
            mtrans.ToRandomTransform(mtrans.VFlip(), prob=0.5),
            mtrans.ToResize((300, 300),
                            resize_mode='squash',
                            padding_mode=cv2.BORDER_REFLECT_101),
            mtrans.RandomCrop((256, 256),
                              limit=10,
                              padding_mode=cv2.BORDER_REFLECT_101),
            mtrans.RandomScale(factor=0.2,
                               padding_mode=cv2.BORDER_REFLECT_101),
            mtrans.RandomGeometricalTransform(
                angle=30,
                translation=0.2,
                warp=0.02,
                padding_mode=cv2.BORDER_REFLECT_101),
            #mtrans.ToResizeUNetFoV(imsize, cv2.BORDER_REFLECT_101),
            mtrans.ToRandomTransform(mtrans.RandomBrightness(factor=0.15),
                                     prob=0.50),
            mtrans.ToRandomTransform(mtrans.RandomContrast(factor=0.15),
                                     prob=0.50),
            mtrans.ToRandomTransform(mtrans.RandomGamma(factor=0.15),
                                     prob=0.50),
            mtrans.ToRandomTransform(mtrans.ToGaussianBlur(), prob=0.15),
            mtrans.ToTensor(),
            mtrans.ToMeanNormalization(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
            #mtrans.ToNormalization(),
        ]))

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=network.cuda,
                              drop_last=True)

    # validate dataset
    val_data = tgsdata.TGSDataset(
        args.data,
        tgsdata.test,
        count=4000,
        num_channels=num_channels,
        transform=transforms.Compose([
            mtrans.ToResize((256, 256), resize_mode='squash'),
            #mtrans.RandomCrop( (255,255), limit=50, padding_mode=cv2.BORDER_CONSTANT  ),
            #mtrans.ToResizeUNetFoV(imsize, cv2.BORDER_REFLECT_101),
            mtrans.ToTensor(),
            mtrans.ToMeanNormalization(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
            #mtrans.ToNormalization(),
        ]))

    val_loader = DataLoader(val_data,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=network.cuda,
                            drop_last=True)

    # print neural net class
    print('SEG-Torch: {}'.format(datetime.datetime.now()))
    print(network)

    # training neural net
    network.fit(train_loader, val_loader, args.epochs, args.snapshot)

    print("Optimization Finished!")
    print("DONE!!!")
Code Example #14
from pytvision.datasets.syntheticdata import SyntethicCircleDataset
from pytvision.transforms import transforms as mtrans
from pytvision import visualization as view


data = SyntethicCircleDataset(
        count=100,
        generate=SyntethicCircleDataset.generate_image_mask_and_weight,
        imsize=(512,512),
        sigma=0.01,
        bdraw_grid=True,
        transform=transforms.Compose([

              ## resize and crop
                           
              mtrans.ToResize( (400,400), resize_mode='asp' ) ,
              #mtrans.CenterCrop( (200,200) ),
              mtrans.RandomCrop( (255,255), limit=50, padding_mode=cv2.BORDER_REFLECT_101  ),
              #mtrans.ToResizeUNetFoV(388, cv2.BORDER_REFLECT_101),
              
              ## color 

              #mtrans.RandomSaturation(),
              #mtrans.RandomHueSaturationShift(),
              #mtrans.RandomHueSaturation(),
              #mtrans.RandomRGBShift(),
              #mtrans.ToNegative(),
              #mtrans.RandomRGBPermutation(),
              #mtrans.ToGrayscale(),

              ## blur
Code Example #15
from pytvision.transforms import transforms as mtrans
from pytvision import visualization as view

pathdataset = '/home/pdmf/.datasets/'
namedataset = 'cellcaltech0001'
sub_folder = ''
folders_images = 'images'
folders_labels = 'labels'
base_folder = os.path.join(pathdataset, namedataset)

data = CTECHDataset(
    base_folder,
    sub_folder,
    count=10,
    transform=transforms.Compose([
        mtrans.ToResize((500, 500), resize_mode='crop'),
        mtrans.RandomCrop((255, 255),
                          limit=50,
                          padding_mode=cv2.BORDER_CONSTANT),
        #mtrans.ToResizeUNetFoV(388, cv2.BORDER_REFLECT_101),
        mtrans.ToTensor(),
        mtrans.ToNormalization(),
    ]))

# sample = data[0]
# for k,v in sample.items():
#     print( k, ':', v.shape, v.min(), v.max() )
# print('\n')
# #assert(False)

dataloader = DataLoader(data, batch_size=3, shuffle=False, num_workers=1)
Code Example #16
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

root = '~/.datasets'
pathname = os.path.expanduser(root)
name_dataset = FactoryDataset.afew

dataloader = Dataset(
    data=FactoryDataset.factory(pathname=pathname, 
        name=name_dataset, 
        subset=FactoryDataset.validation, 
        download=True ),
    num_channels=3,
    transform=transforms.Compose([
        mtrans.ToResize( (64, 64), resize_mode='square', padding_mode=cv2.BORDER_REFLECT101 ),
        mtrans.ToTensor(),
        mtrans.ToNormalization(),
        ])
    )

print(dataloader.labels)
print(dataloader.classes)
#print(dataloader.data.classes)
#print(dataloader.data.class_to_idx)

print( dataloader[0]['label'].shape, dataloader[0]['image'].shape )
print( len(dataloader) )

plt.figure( figsize=(16,16))
view.visualizatedataset(dataloader, num=100, imsize=(64,64,3) )
Code Example #17
def test_ellipse():

    data = SyntethicCircleDataset(
        count=300,
        generate=SyntethicCircleDataset.generate_image_mask_and_weight,
        imsize=(512, 612),
        sigma=0.01,
        bdraw_grid=True,
        transform=transforms.Compose([

            ## resize and crop
            mtrans.ToResize((400, 400),
                            resize_mode='square',
                            padding_mode=cv2.BORDER_REFLECT_101),
            #mtrans.CenterCrop( (200,200) ),
            #mtrans.RandomCrop( (255,255), limit=50, padding_mode=cv2.BORDER_REFLECT_101  ),
            #mtrans.ToResizeUNetFoV(388, cv2.BORDER_REFLECT_101),

            ## color
            mtrans.ToRandomChoiceTransform([
                mtrans.RandomSaturation(),
                mtrans.RandomHueSaturationShift(),
                mtrans.RandomHueSaturation(),
                #mtrans.RandomRGBShift(),
                #mtrans.ToNegative(),
                #mtrans.RandomRGBPermutation(),
                #mtrans.ToRandomTransform( mtrans.ToGrayscale(), prob=0.5 ),
                mtrans.ToGrayscale(),
            ]),

            ## blur
            #mtrans.ToRandomTransform( mtrans.ToLinealMotionBlur( lmax=1 ), prob=0.5 ),
            #mtrans.ToRandomTransform( mtrans.ToMotionBlur( ), prob=0.5 ),
            mtrans.ToRandomTransform(mtrans.ToGaussianBlur(), prob=0.75),

            ## geometrical
            #mtrans.ToRandomTransform( mtrans.HFlip(), prob=0.5 )
            #mtrans.ToRandomTransform( mtrans.VFlip(), prob=0.5 )
            mtrans.RandomScale(factor=0.2, padding_mode=cv2.BORDER_REFLECT101),
            #mtrans.RandomGeometricalTransform( angle=360, translation=0.2, warp=0.02, padding_mode=cv2.BORDER_REFLECT101),
            #mtrans.RandomElasticDistort( size_grid=50, padding_mode=cv2.BORDER_REFLECT101 ),

            ## tensor
            mtrans.ToTensor(),
            mtrans.RandomElasticTensorDistort(size_grid=10, deform=0.05),

            ## normalization
            mtrans.ToNormalization(),
            #mtrans.ToWhiteNormalization(),
            #mtrans.ToMeanNormalization(
            #    mean=[0.485, 0.456, 0.406],
            #    std=[0.229, 0.224, 0.225]
            #    ),
        ]))

    dataloader = DataLoader(data, batch_size=3, shuffle=True, num_workers=1)

    label_batched = []
    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['label'].size(), sample_batched['weight'].size())

        image = sample_batched['image']
        label = sample_batched['label']
        weight = sample_batched['weight']

        print(torch.min(image), torch.max(image), image.shape)
        print(torch.min(label), torch.max(label), label.shape)
        print(torch.min(weight), torch.max(weight), weight.shape)

        print(image.shape)
        print(np.unique(label))
        print(image.min(), image.max())

        image = image.permute(2, 3, 1, 0)[:, :, :, 0].squeeze()
        label = label.permute(2, 3, 1, 0)[:, :, :, 0].squeeze()
        weight = weight.permute(2, 3, 1, 0)[:, :, :, 0].squeeze()

        plt.figure(figsize=(16, 4))
        plt.subplot(131)
        plt.imshow(image)
        plt.title('image + [grid]')
        plt.axis('off')
        plt.ioff()

        plt.subplot(132)
        plt.imshow(label)
        plt.title('gt')
        plt.axis('off')
        plt.ioff()

        plt.subplot(133)
        plt.imshow(weight)
        plt.title('weight map')
        plt.axis('off')
        plt.ioff()

        #print('save figure ...')
        #plt.savefig('../out/image_{}.png'.format(i_batch) )

        plt.show()

        # observe the second batch and stop.
        if i_batch == 1:
            break