def get_transformations(p):
    """ Return transformations for training and evaluation """
    from data import custom_transforms as tr
    from torchvision import transforms

    # Training transformations
    if p['train_db_name'] == 'NYUD':
        # Horizontal flips with probability of 0.5
        transforms_tr = [tr.RandomHorizontalFlip()]

        # Rotations and scaling
        transforms_tr.extend([tr.ScaleNRotate(rots=[0], scales=[1.0, 1.2, 1.5],
                                              flagvals={x: p.ALL_TASKS.FLAGVALS[x] for x in p.ALL_TASKS.FLAGVALS})])

    elif p['train_db_name'] == 'PASCALContext':
        # Horizontal flips with probability of 0.5
        transforms_tr = [tr.RandomHorizontalFlip()]

        # Rotations and scaling
        transforms_tr.extend([tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25),
                                              flagvals={x: p.ALL_TASKS.FLAGVALS[x] for x in p.ALL_TASKS.FLAGVALS})])

    else:
        raise ValueError('Invalid train db name {}'.format(p['train_db_name']))

    # Fixed resize to the input resolution
    transforms_tr.extend([tr.FixedResize(resolutions={x: tuple(p.TRAIN.SCALE) for x in p.ALL_TASKS.FLAGVALS},
                                         flagvals={x: p.ALL_TASKS.FLAGVALS[x] for x in p.ALL_TASKS.FLAGVALS})])
    transforms_tr.extend([tr.AddIgnoreRegions(), tr.ToTensor(),
                          tr.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    transforms_tr = transforms.Compose(transforms_tr)

    # Testing transformations (used for evaluation during training)
    transforms_ts = []
    transforms_ts.extend([tr.FixedResize(resolutions={x: tuple(p.TEST.SCALE) for x in p.TASKS.FLAGVALS},
                                         flagvals={x: p.TASKS.FLAGVALS[x] for x in p.TASKS.FLAGVALS})])
    transforms_ts.extend([tr.AddIgnoreRegions(), tr.ToTensor(),
                          tr.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    transforms_ts = transforms.Compose(transforms_ts)

    return transforms_tr, transforms_ts
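# A minimal usage sketch for get_transformations, assuming an EasyDict-style
# config that supports both p['key'] and p.KEY access (the function itself
# mixes both). The database name, scales, and interpolation flags below are
# illustrative assumptions, not values taken from a real config file.
def _example_get_transformations():
    import cv2
    from easydict import EasyDict

    p = EasyDict()
    p['train_db_name'] = 'PASCALContext'
    p.TRAIN = EasyDict({'SCALE': (512, 512)})
    p.TEST = EasyDict({'SCALE': (512, 512)})
    # One interpolation flag per task; ALL_TASKS also covers auxiliary tasks.
    flagvals = {'image': cv2.INTER_CUBIC, 'semseg': cv2.INTER_NEAREST}
    p.ALL_TASKS = EasyDict({'FLAGVALS': flagvals})
    p.TASKS = EasyDict({'FLAGVALS': flagvals})

    transforms_tr, transforms_ts = get_transformations(p)
    print(transforms_tr)
    print(transforms_ts)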
def test_mt():
    import cv2
    import torch
    import numpy as np
    import data.custom_transforms as tr
    import matplotlib.pyplot as plt
    from torchvision import transforms

    transform = transforms.Compose([tr.RandomHorizontalFlip(),
                                    tr.ScaleNRotate(rots=(-2, 2), scales=(.75, 1.25),
                                                    flagvals={'image': cv2.INTER_CUBIC,
                                                              'edge': cv2.INTER_NEAREST,
                                                              'semseg': cv2.INTER_NEAREST,
                                                              'normals': cv2.INTER_LINEAR,
                                                              'depth': cv2.INTER_LINEAR}),
                                    tr.FixedResize(resolutions={'image': (512, 512),
                                                                'edge': (512, 512),
                                                                'semseg': (512, 512),
                                                                'normals': (512, 512),
                                                                'depth': (512, 512)},
                                                   flagvals={'image': cv2.INTER_CUBIC,
                                                             'edge': cv2.INTER_NEAREST,
                                                             'semseg': cv2.INTER_NEAREST,
                                                             'normals': cv2.INTER_LINEAR,
                                                             'depth': cv2.INTER_LINEAR}),
                                    tr.AddIgnoreRegions(),
                                    tr.ToTensor()])

    # NYUD_MT is assumed to be defined in (or importable into) this module.
    dataset = NYUD_MT(split='train', transform=transform, retname=True,
                      do_edge=True, do_semseg=True, do_normals=True, do_depth=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=5)

    for i, sample in enumerate(dataloader):
        print(i)
        for j in range(sample['image'].shape[0]):
            f, ax_arr = plt.subplots(5)
            for k in range(len(ax_arr)):
                ax_arr[k].cla()
            ax_arr[0].imshow(np.transpose(sample['image'][j], (1, 2, 0)))
            ax_arr[1].imshow(sample['edge'][j, 0])
            ax_arr[2].imshow(sample['semseg'][j, 0] / 40)
            ax_arr[3].imshow(np.transpose(sample['normals'][j], (1, 2, 0)))
            max_depth = torch.max(sample['depth'][j, 0][sample['depth'][j, 0] != 255]).item()
            # Not ideal. Better is to show inverse depth (see the sketch below).
            ax_arr[4].imshow(sample['depth'][j, 0] / max_depth)
            plt.show()
        break
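# test_mt above notes that plotting raw depth is "not ideal" and that inverse
# depth is better. A minimal sketch of that alternative: inverse depth
# compresses far-away values and emphasizes near-range structure. The helper
# name is hypothetical; the 255 ignore value mirrors the mask used in test_mt.
def _show_inverse_depth(ax, depth, ignore_val=255):
    import numpy as np

    depth = np.asarray(depth, dtype=np.float32)
    inv = np.zeros_like(depth)
    valid = depth != ignore_val
    # Guard against division by (near-)zero depth values.
    inv[valid] = 1.0 / np.maximum(depth[valid], 1e-6)
    ax.imshow(inv / inv.max() if inv.max() > 0 else inv)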
def test_all():
    import cv2
    import torch
    import numpy as np
    import data.custom_transforms as tr
    import matplotlib.pyplot as plt
    from torchvision import transforms

    transform = transforms.Compose([tr.RandomHorizontalFlip(),
                                    tr.ScaleNRotate(rots=(-90, 90), scales=(1., 1.),
                                                    flagvals={'image': cv2.INTER_CUBIC,
                                                              'edge': cv2.INTER_NEAREST,
                                                              'semseg': cv2.INTER_NEAREST,
                                                              'human_parts': cv2.INTER_NEAREST,
                                                              'normals': cv2.INTER_CUBIC,
                                                              'sal': cv2.INTER_NEAREST}),
                                    tr.FixedResize(resolutions={'image': (512, 512),
                                                                'edge': (512, 512),
                                                                'semseg': (512, 512),
                                                                'human_parts': (512, 512),
                                                                'normals': (512, 512),
                                                                'sal': (512, 512)},
                                                   flagvals={'image': cv2.INTER_CUBIC,
                                                             'edge': cv2.INTER_NEAREST,
                                                             'semseg': cv2.INTER_NEAREST,
                                                             'human_parts': cv2.INTER_NEAREST,
                                                             'normals': cv2.INTER_CUBIC,
                                                             'sal': cv2.INTER_NEAREST}),
                                    tr.AddIgnoreRegions(),
                                    tr.ToTensor()])

    # PASCALContext is assumed to be defined in (or importable into) this module.
    dataset = PASCALContext(split='train', transform=transform, retname=True,
                            do_edge=True, do_semseg=True, do_human_parts=True,
                            do_normals=True, do_sal=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0)

    for i, sample in enumerate(dataloader):
        print(i)
        for j in range(sample['image'].shape[0]):
            f, ax_arr = plt.subplots(2, 3)
            for k in range(len(ax_arr)):
                for l in range(len(ax_arr[k])):
                    ax_arr[k][l].cla()
            ax_arr[0][0].imshow(np.transpose(sample['image'][j], (1, 2, 0)))
            ax_arr[0][1].imshow(np.transpose(sample['edge'][j], (1, 2, 0))[:, :, 0])
            ax_arr[0][2].imshow(np.transpose(sample['semseg'][j], (1, 2, 0))[:, :, 0] / 20.)
            ax_arr[1][0].imshow(np.transpose(sample['human_parts'][j], (1, 2, 0))[:, :, 0] / 6.)
            ax_arr[1][1].imshow(np.transpose(sample['normals'][j], (1, 2, 0)))
            ax_arr[1][2].imshow(np.transpose(sample['sal'][j], (1, 2, 0))[:, :, 0])
            plt.show()
        break
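# A minimal sketch of how these smoke tests might be run directly. Both calls
# assume the dataset classes referenced above (NYUD_MT, PASCALContext) are
# defined or importable in this module and that the corresponding data is
# available on disk.
if __name__ == '__main__':
    test_mt()
    test_all()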