def get_transformations(p):
    """Return transformations for training and evaluation.

    Args:
        p: Experiment config. It is read both by key (``p['train_db_name']``)
            and by attribute (``p.ALL_TASKS.FLAGVALS``, ``p.TASKS.FLAGVALS``,
            ``p.TRAIN.SCALE``, ``p.TEST.SCALE``).

    Returns:
        ``(transforms_tr, transforms_ts)``, two ``transforms.Compose`` pipelines.
        The training pipeline applies a random horizontal flip and a
        dataset-specific ``ScaleNRotate``, then resizes every task in
        ``p.ALL_TASKS`` to ``p.TRAIN.SCALE``. The test pipeline only resizes the
        tasks in ``p.TASKS`` to ``p.TEST.SCALE``. Both end with
        ``AddIgnoreRegions``, ``ToTensor`` and ``Normalize`` (the usual
        ImageNet channel mean/std values).

    Raises:
        ValueError: if ``p['train_db_name']`` is neither ``'NYUD'`` nor
            ``'PASCALContext'``.
    """
    # Imported locally (as the test helpers in this file do) so the function
    # does not rely on a module-level ``transforms`` name.
    from torchvision import transforms
    from data import custom_transforms as tr

    # Dataset-specific tr.ScaleNRotate arguments. NYUD passes lists while
    # PASCALContext passes tuples. NOTE(review): presumably ScaleNRotate picks
    # from a list but samples uniformly within a (low, high) tuple -- confirm
    # in data/custom_transforms.
    scale_rotate_args = {
        'NYUD': {'rots': [0], 'scales': [1.0, 1.2, 1.5]},
        'PASCALContext': {'rots': (-20, 20), 'scales': (.75, 1.25)},
    }
    db_name = p['train_db_name']
    if db_name not in scale_rotate_args:
        # The original message had no placeholder, so the bad name was dropped.
        raise ValueError('Invalid train db name {}'.format(db_name))

    def _finish(pipeline):
        """Append ignore-region, tensor and normalisation steps, then compose."""
        pipeline.extend([tr.AddIgnoreRegions(), tr.ToTensor(),
                         tr.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        return transforms.Compose(pipeline)

    # Training transformations cover every task in p.ALL_TASKS. Each transform
    # receives its own copy of the flag table, as in the original code.
    train_flagvals = p.ALL_TASKS.FLAGVALS
    transforms_tr = _finish([
        # Horizontal flips with probability of 0.5
        tr.RandomHorizontalFlip(),
        # Rotations and scaling
        tr.ScaleNRotate(flagvals=dict(train_flagvals), **scale_rotate_args[db_name]),
        # Fixed resize to the training input resolution
        tr.FixedResize(resolutions=dict.fromkeys(train_flagvals, tuple(p.TRAIN.SCALE)),
                       flagvals=dict(train_flagvals)),
    ])

    # Testing (during training) transformations: no augmentation, and only the
    # tasks in p.TASKS. NOTE(review): p.TASKS is presumably a subset of
    # p.ALL_TASKS (extra tasks used only for training) -- confirm in the config.
    test_flagvals = p.TASKS.FLAGVALS
    transforms_ts = _finish([
        tr.FixedResize(resolutions=dict.fromkeys(test_flagvals, tuple(p.TEST.SCALE)),
                       flagvals=dict(test_flagvals)),
    ])

    return transforms_tr, transforms_ts
def test_mt():
    """Visually inspect augmented samples from ``NYUD_MT``.

    Builds a flip / scale-rotate / resize-to-512x512 pipeline over the image,
    edge, semseg, normals and depth targets. It then loads the first batch of
    the 'train' split and shows one 5-panel matplotlib figure per sample.
    Depends on ``cv2`` and ``NYUD_MT`` being available at module level.
    """
    import torch
    import data.custom_transforms as tr
    import matplotlib.pyplot as plt
    from torchvision import transforms

    # Interpolation flag per task. The same table drives both the
    # scale/rotate and the resize step; each step gets its own copy.
    flagvals = {
        'image': cv2.INTER_CUBIC,
        'edge': cv2.INTER_NEAREST,
        'semseg': cv2.INTER_NEAREST,
        'normals': cv2.INTER_LINEAR,
        'depth': cv2.INTER_LINEAR,
    }
    transform = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-2, 2), scales=(.75, 1.25), flagvals=dict(flagvals)),
        tr.FixedResize(resolutions=dict.fromkeys(flagvals, (512, 512)),
                       flagvals=dict(flagvals)),
        tr.AddIgnoreRegions(),
        tr.ToTensor(),
    ])
    dataset = NYUD_MT(split='train',
                      transform=transform,
                      retname=True,
                      do_edge=True,
                      do_semseg=True,
                      do_normals=True,
                      do_depth=True)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=5,
                                             shuffle=False,
                                             num_workers=5)

    # Only the first batch is inspected.
    for i, sample in enumerate(dataloader):
        print(i)
        for j in range(sample['image'].shape[0]):
            # Fresh figure per sample, so there is nothing to clear.
            _, axes = plt.subplots(5)
            axes[0].imshow(sample['image'][j].permute(1, 2, 0))
            axes[1].imshow(sample['edge'][j, 0])
            axes[2].imshow(sample['semseg'][j, 0] / 40)
            axes[3].imshow(sample['normals'][j].permute(1, 2, 0))

            # Scale depth by its largest non-ignored (!= 255) value. Not ideal;
            # showing inverse depth would be better. Guard against a map with
            # no valid pixels (torch.max would raise) or a zero maximum.
            depth = sample['depth'][j, 0]
            valid = depth[depth != 255]
            if valid.numel() > 0:
                max_depth = valid.max().item()
                if max_depth > 0:
                    depth = depth / max_depth
            axes[4].imshow(depth)

            plt.show()
        break
# Example #3
# 0
def test_all():
    """Visually inspect augmented samples from ``PASCALContext``.

    Builds a flip / rotate / resize-to-512x512 pipeline over the image, edge,
    semseg, human_parts, normals and sal targets. It then loads the first
    batch of the 'train' split and shows one 2x3 matplotlib grid per sample.
    Depends on ``cv2`` and ``PASCALContext`` being available at module level.
    """
    import matplotlib.pyplot as plt
    import torch
    import data.custom_transforms as tr
    from torchvision import transforms

    # Interpolation flag per task. The same table drives both the
    # scale/rotate and the resize step; each step gets its own copy.
    flagvals = {
        'image': cv2.INTER_CUBIC,
        'edge': cv2.INTER_NEAREST,
        'semseg': cv2.INTER_NEAREST,
        'human_parts': cv2.INTER_NEAREST,
        'normals': cv2.INTER_CUBIC,
        'sal': cv2.INTER_NEAREST,
    }
    transform = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-90, 90), scales=(1., 1.), flagvals=dict(flagvals)),
        tr.FixedResize(resolutions=dict.fromkeys(flagvals, (512, 512)),
                       flagvals=dict(flagvals)),
        tr.AddIgnoreRegions(),
        tr.ToTensor(),
    ])
    dataset = PASCALContext(split='train',
                            transform=transform,
                            retname=True,
                            do_edge=True,
                            do_semseg=True,
                            do_human_parts=True,
                            do_normals=True,
                            do_sal=True)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=2,
                                             shuffle=False,
                                             num_workers=0)

    # Only the first batch is inspected.
    for i, sample in enumerate(dataloader):
        print(i)
        for j in range(sample['image'].shape[0]):
            # Fresh figure per sample, so there is nothing to clear.
            _, axes = plt.subplots(2, 3)

            # Single-channel maps: channel 0 of sample j (same values as the
            # former np.transpose(..., (1, 2, 0))[:, :, 0]).
            axes[0][0].imshow(sample['image'][j].permute(1, 2, 0))
            axes[0][1].imshow(sample['edge'][j, 0])
            axes[0][2].imshow(sample['semseg'][j, 0] / 20.)
            axes[1][0].imshow(sample['human_parts'][j, 0] / 6.)
            axes[1][1].imshow(sample['normals'][j].permute(1, 2, 0))
            axes[1][2].imshow(sample['sal'][j, 0])

            plt.show()
        break